RecordingRAG/02_label_pool.py
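"""Build a manual-labeling pool for recording-transcript retrieval.

Loads chunked transcripts (chunks.json) and evaluation queries (queries.jsonl),
embeds both with a sentence-transformers model, retrieves the top-M candidate
chunks per query from a FAISS inner-product index (reusing an existing index if
one is found, otherwise building and saving recording_text.index plus a
recording_docmap.json), and writes label_pool_recording.csv with an empty 'rel'
column for human relevance labels.

Example invocation (paths are illustrative):
    python 02_label_pool.py --chunks ./lecture_chunks.json --queries ./queries.jsonl \
        --out_dir ./out --model_id sentence-transformers/all-mpnet-base-v2 --format auto
"""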

import argparse
import csv
import json
import os
import random
import sys

import faiss
import numpy as np
def load_chunks(path):
    with open(path, "r", encoding="utf-8") as f:
        chunks = json.load(f)
    # minimal hygiene
    for c in chunks:
        c.setdefault("text", "")
        if "chunk_index" not in c:
            raise ValueError("chunks.json must include 'chunk_index' per entry")
        if "file" not in c:
            raise ValueError("chunks.json must include 'file' per entry")
    return chunks
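# Expected chunks.json entry (field names as required by load_chunks; values illustrative):
#   {"file": "lecture01.txt", "chunk_index": 0, "text": "Today we introduce ..."}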
def load_queries(path):
    qs = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines in the JSONL file
            r = json.loads(line)
            qs.append((r["qid"], r["query"]))
    return qs
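# Expected queries.jsonl line (field names as required by load_queries; values illustrative):
#   {"qid": "q001", "query": "How does the lecturer define overfitting?"}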
def build_index(embs: np.ndarray) -> faiss.IndexFlatIP:
    embs = embs.astype("float32")
    # With L2-normalized vectors, inner-product search is equivalent to cosine similarity.
    faiss.normalize_L2(embs)
    idx = faiss.IndexFlatIP(embs.shape[1])
    idx.add(embs)
    return idx
def maybe_load_index(index_path: str):
    return faiss.read_index(index_path) if (index_path and os.path.exists(index_path)) else None
def ensure_dir(p): os.makedirs(p, exist_ok=True)
def detect_family(model_id: str) -> str:
    mid = model_id.lower()
    if "e5" in mid and "multilingual" not in mid:
        return "e5"   # e5-base-v2 (EN)
    if "bge" in mid:
        return "bge"  # BAAI/bge-*-en*
    # MPNet, MiniLM, GTE, etc. -> no special formatting
    return "none"
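# e.g. detect_family("intfloat/e5-base-v2") -> "e5",
#      detect_family("sentence-transformers/all-mpnet-base-v2") -> "none"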
def format_passages(texts, fmt: str):
    if fmt == "e5":
        return [f"passage: {t.strip()}" for t in texts]
    if fmt == "bge":
        return [f"Represent this document for retrieval: {t.strip()}" for t in texts]
    return [t if isinstance(t, str) else str(t) for t in texts]

def format_queries(texts, fmt: str):
    if fmt == "e5":
        return [f"query: {t.strip()}" for t in texts]
    if fmt == "bge":
        return [f"Represent this query for retrieval: {t.strip()}" for t in texts]
    return [t.strip() for t in texts]
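# e.g. format_queries(["what is backpropagation"], "e5") -> ["query: what is backpropagation"]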
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--chunks", default="./lecture_chunks.json",
                    help="chunks.json with fields: file, chunk_index, text")
    ap.add_argument("--queries", default="./queries.jsonl",
                    help="queries.jsonl with {qid, query}")
    ap.add_argument("--out_dir", default="./",
                    help="Where to write label_pool_recording.csv and the index if one is built")
    ap.add_argument("--faiss_index", default="./out/embeddings_v0.2_enhanced.faiss",
                    help="Optional existing FAISS index for the chunks; if missing, build one")
    ap.add_argument("--model_id", default="sentence-transformers/all-mpnet-base-v2",
                    help="e.g., sentence-transformers/all-mpnet-base-v2, BAAI/bge-base-en-v1.5, "
                         "Alibaba-NLP/gte-base-en-v1.5, intfloat/e5-base-v2")
    ap.add_argument("--format", default="none", choices=["auto", "none", "e5", "bge"],
                    help="Query/passage prompt formatting. 'auto' infers from model_id.")
    ap.add_argument("--topM", type=int, default=20,
                    help="How many candidate chunks are retrieved per query")
    ap.add_argument("--batch", type=int, default=128)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)
    ensure_dir(args.out_dir)
    chunks = load_chunks(args.chunks)
    queries = load_queries(args.queries)
    # choose formatting
    fmt = args.format
    if fmt == "auto":
        fmt = detect_family(args.model_id)
    print(f"[INFO] model_id={args.model_id} | format={fmt}")
    # encoder
    try:
        from sentence_transformers import SentenceTransformer
        enc = SentenceTransformer(args.model_id)
    except Exception as e:
        print(f"[ERROR] Failed to load model {args.model_id}: {e}", file=sys.stderr)
        sys.exit(1)
    # index: use the provided one or build a new one from the chunks
    index = maybe_load_index(args.faiss_index)
    if index is None:
        passages = format_passages([c["text"] for c in chunks], fmt)
        doc_embs = enc.encode(passages, batch_size=args.batch, convert_to_numpy=True,
                              normalize_embeddings=True).astype("float32")
        index = build_index(doc_embs)
        faiss_path = os.path.join(args.out_dir, "recording_text.index")
        faiss.write_index(index, faiss_path)
        # save a row -> (file, chunk_index) docmap so index rows stay traceable
        with open(os.path.join(args.out_dir, "recording_docmap.json"), "w", encoding="utf-8") as f:
            json.dump([{"row": i, "file": c["file"], "chunk_index": int(c["chunk_index"])}
                       for i, c in enumerate(chunks)],
                      f, ensure_ascii=False, indent=2)
        print(f"[OK] Built index → {faiss_path}")
    else:
        assert index.ntotal == len(chunks), \
            f"FAISS index has {index.ntotal} vectors but chunks.json has {len(chunks)} entries."
    # queries → embeddings
    q_texts = format_queries([q for _, q in queries], fmt)
    q_embs = enc.encode(q_texts, batch_size=args.batch, convert_to_numpy=True,
                        normalize_embeddings=True).astype("float32")
    # Retrieve the topM chunks per query embedding: similarity scores in
    # descending order and their corresponding row indices into chunks.
    sims, ids = index.search(q_embs, args.topM)
    # write the labeling CSV: one row per (query, candidate) pair, 'rel' left empty for annotators
    out_csv = os.path.join(args.out_dir, "label_pool_recording.csv")
    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["qid", "query", "file", "chunk_index", "candidate_rank", "base_score", "preview", "rel"])
        for (qid, q), row_ids, row_sims in zip(queries, ids, sims):
            seen = set()
            rank = 0
            for sid, sc in zip(row_ids.tolist(), row_sims.tolist()):
                c = chunks[sid]
                key = (c["file"], int(c["chunk_index"]))
                if key in seen:  # skip duplicate (file, chunk_index) candidates
                    continue
                seen.add(key)
                rank += 1
                preview = (c["text"] or "").replace("\n", " ")[:300]
                w.writerow([qid, q, c["file"], int(c["chunk_index"]), rank, round(float(sc), 4), preview, ""])
    print(f"[OK] Wrote labeling sheet: {out_csv}")
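    # Illustrative output row (values made up; 'rel' is left blank and filled in by hand,
    # e.g. 0 = not relevant, 1 = relevant -- the exact scale is up to the annotator):
    #   q001,How does the lecturer define overfitting?,lecture03.txt,12,1,0.6123,"Today we ...",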
if __name__ == "__main__":
    main()