"""Build a labeling pool for retrieval evaluation: embed chunks and queries,
retrieve the top-M chunks per query with FAISS, and write them to
label_pool_recording.csv for manual relevance annotation."""

import argparse
import csv
import json
import os
import random
import sys

import faiss
import numpy as np


def load_chunks(path):
    with open(path, "r", encoding="utf-8") as f:
        chunks = json.load(f)
    # minimal hygiene
    for c in chunks:
        c.setdefault("text", "")
        if "chunk_index" not in c:
            raise ValueError("chunks.json must include 'chunk_index' per entry")
        if "file" not in c:
            raise ValueError("chunks.json must include 'file' per entry")
    return chunks


def load_queries(path):
    qs = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            r = json.loads(line)
            qs.append((r["qid"], r["query"]))
    return qs


def build_index(embs: np.ndarray) -> faiss.IndexFlatIP:
    embs = embs.astype("float32")
    # L2-normalize so inner product equals cosine similarity
    # (a no-op if the embeddings are already normalized).
    faiss.normalize_L2(embs)
    idx = faiss.IndexFlatIP(embs.shape[1])
    idx.add(embs)
    return idx


def maybe_load_index(index_path: str):
    return faiss.read_index(index_path) if (index_path and os.path.exists(index_path)) else None


def ensure_dir(p):
    os.makedirs(p, exist_ok=True)


def detect_family(model_id: str) -> str:
    mid = model_id.lower()
    if "e5" in mid and "multilingual" not in mid:
        return "e5"   # e5-base-v2 (EN)
    if "bge" in mid:
        return "bge"  # BAAI/bge-*-en*
    # MPNet, MiniLM, GTE, etc. -> no special formatting
    return "none"


def format_passages(texts, fmt: str):
    if fmt == "e5":
        return [f"passage: {t.strip()}" for t in texts]
    # bge-*-en-v1.5 passages are encoded without an instruction prefix,
    # so "bge" falls through to the plain-text path along with "none".
    return [t if isinstance(t, str) else str(t) for t in texts]


def format_queries(texts, fmt: str):
    if fmt == "e5":
        return [f"query: {t.strip()}" for t in texts]
    if fmt == "bge":
        # Query instruction recommended by the BAAI/bge-*-en-v1.5 model cards.
        return [f"Represent this sentence for searching relevant passages: {t.strip()}" for t in texts]
    return [t.strip() for t in texts]


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--chunks", default="./lecture_chunks.json",
                    help="chunks.json with fields: file, chunk_index, text")
    ap.add_argument("--queries", default="./queries.jsonl",
                    help="queries.jsonl with {qid, query}")
    ap.add_argument("--out_dir", default="./",
                    help="Where to write label_pool_recording.csv and the index if built")
    ap.add_argument("--faiss_index", default="./out/embeddings_v0.2_enhanced.faiss",
                    help="Optional existing FAISS index for the chunks; if missing, build one")
    ap.add_argument("--model_id", default="sentence-transformers/all-mpnet-base-v2",
                    help="e.g., sentence-transformers/all-mpnet-base-v2, BAAI/bge-base-en-v1.5, "
                         "Alibaba-NLP/gte-base-en-v1.5, intfloat/e5-base-v2")
    ap.add_argument("--format", default="none", choices=["auto", "none", "e5", "bge"],
                    help="Query/passage prompt formatting. 'auto' infers from model_id.")
    ap.add_argument("--topM", type=int, default=20,
                    help="How many candidate chunks are retrieved per query")
    ap.add_argument("--batch", type=int, default=128)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    ensure_dir(args.out_dir)

    chunks = load_chunks(args.chunks)
    queries = load_queries(args.queries)

    # choose formatting
    fmt = args.format
    if fmt == "auto":
        fmt = detect_family(args.model_id)
    print(f"[INFO] model_id={args.model_id} | format={fmt}")

    # encoder
    try:
        from sentence_transformers import SentenceTransformer
        enc = SentenceTransformer(args.model_id)
    except Exception as e:
        print(f"[ERROR] Failed to load model {args.model_id}: {e}", file=sys.stderr)
        sys.exit(1)

    # index: use the provided one or build from scratch
    index = maybe_load_index(args.faiss_index)
    if index is None:
        passages = format_passages([c["text"] for c in chunks], fmt)
        doc_embs = enc.encode(passages, batch_size=args.batch, convert_to_numpy=True,
                              normalize_embeddings=True).astype("float32")
        index = build_index(doc_embs)
        faiss_path = os.path.join(args.out_dir, "recording_text.index")
        faiss.write_index(index, faiss_path)
        # save docmap for safety: row in the index -> (file, chunk_index)
        with open(os.path.join(args.out_dir, "recording_docmap.json"), "w", encoding="utf-8") as f:
            json.dump([{"row": i, "file": c["file"], "chunk_index": int(c["chunk_index"])}
                       for i, c in enumerate(chunks)], f, ensure_ascii=False, indent=2)
        print(f"[OK] Built index → {faiss_path}")
    else:
        assert index.ntotal == len(chunks), \
            f"FAISS has {index.ntotal} vectors but chunks.json has {len(chunks)} entries."

    # queries → embeddings
    q_texts = format_queries([q for _, q in queries], fmt)
    q_embs = enc.encode(q_texts, batch_size=args.batch, convert_to_numpy=True,
                        normalize_embeddings=True).astype("float32")

    # Retrieve the topM chunks for each query embedding: similarity scores in
    # descending order and their corresponding row indices into chunks.json.
    sims, ids = index.search(q_embs, args.topM)

    # write labeling CSV
    out_csv = os.path.join(args.out_dir, "label_pool_recording.csv")
    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["qid", "query", "file", "chunk_index", "candidate_rank", "base_score", "preview", "rel"])
        for (qid, q), row_ids, row_sims in zip(queries, ids, sims):
            seen = set()
            rank = 0
            for sid, sc in zip(row_ids.tolist(), row_sims.tolist()):
                c = chunks[sid]
                key = (c["file"], int(c["chunk_index"]))
                if key in seen:
                    continue
                seen.add(key)
                rank += 1
                preview = (c["text"] or "").replace("\n", " ")[:300]
                w.writerow([qid, q, c["file"], int(c["chunk_index"]), rank,
                            round(float(sc), 4), preview, ""])
    print(f"[OK] Wrote labeling sheet: {out_csv}")


if __name__ == "__main__":
    main()
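
# ---------------------------------------------------------------------------
# Example invocation (a sketch: the script filename, the ./out output paths,
# and the sample query text below are illustrative placeholders, not values
# taken from this script):
#
#   python build_label_pool.py \
#       --chunks ./lecture_chunks.json \
#       --queries ./queries.jsonl \
#       --out_dir ./out \
#       --model_id BAAI/bge-base-en-v1.5 \
#       --format auto \
#       --topM 20
#
# queries.jsonl holds one JSON object per line, e.g.
#   {"qid": "q001", "query": "what does the lecture say about chunk indexing?"}
# and chunks.json is a list of objects with "file", "chunk_index", and "text".
# The resulting label_pool_recording.csv leaves the "rel" column empty so
# relevance judgments can be filled in by hand.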