import os, json, argparse, math, time, csv, sys from typing import List, Dict, Tuple import numpy as np import faiss # ---------- IO ---------- def load_chunks(p): with open(p, "r", encoding="utf-8") as f: chunks = json.load(f) for c in chunks: if "text" not in c or "file" not in c or "chunk_index" not in c: raise ValueError("Each chunk must have 'file', 'chunk_index', 'text'") c.setdefault("text", "") c["chunk_index"] = int(c["chunk_index"]) return chunks def load_qrels(p): qrels, queries = {}, {} with open(p, "r", encoding="utf-8") as f: for line in f: r = json.loads(line) qid = r["qid"] queries[qid] = r.get("query", "") rel_list = r.get("relevant") or r.get("rels") or r.get("labels") or r.get("items") if rel_list is None: raise ValueError(f"{p}: missing 'relevant' field") gains = {(x.get("file") or x.get("pdf_path") or x.get("path"), int(x.get("chunk_index") or x.get("page_index") or (int(x["slide_number"])-1))) : int(x.get("rel", 1)) for x in rel_list} qrels[qid] = gains return queries, qrels def ensure_dir(p): os.makedirs(p, exist_ok=True) # ---------- Metrics ---------- def precision_at_k(ranked, relset, k): k = min(k, len(ranked)); return (sum(1 for d in ranked[:k] if d in relset)/k) if k>0 else 0.0 def recall_at_k(ranked, relset, k): return (sum(1 for d in ranked[:k] if d in relset)/max(1,len(relset))) def average_precision(ranked, relset): if not relset: return 0.0 ap, hits = 0.0, 0 for i,d in enumerate(ranked,1): if d in relset: hits += 1; ap += hits/i return ap/max(1,len(relset)) def reciprocal_rank(ranked, relset): for i,d in enumerate(ranked,1): if d in relset: return 1.0/i return 0.0 def dcg_at_k(ranked, gains, k): dcg=0.0 for i,d in enumerate(ranked[:k],1): g=gains.get(d,0) if g>0: dcg+=(2**g-1)/math.log2(i+1) return dcg def ndcg_at_k(ranked, gains, k): dcg=dcg_at_k(ranked,gains,k) ideal=sorted(gains.values(), reverse=True) idcg=0.0 for i,g in enumerate(ideal[:k],1): idcg+=(2**g-1)/math.log2(i+1) return dcg/idcg if idcg>0 else 0.0 def evaluate_run(run: Dict[str, List[Tuple[str,int]]], qrels: Dict[str, Dict[Tuple[str,int], int]], k_vals=(1,3,5,10)): agg = {f"P@{k}":0.0 for k in k_vals} | {f"R@{k}":0.0 for k in k_vals} | {f"nDCG@{k}":0.0 for k in k_vals} agg["MAP"] = 0.0; agg["MRR"] = 0.0 N = 0 for qid, gains in qrels.items(): ranked = run.get(qid, []) relset = {d for d,g in gains.items() if g > 0} for k in k_vals: agg[f"P@{k}"] += precision_at_k(ranked, relset, k) agg[f"R@{k}"] += recall_at_k(ranked, relset, k) agg[f"nDCG@{k}"] += ndcg_at_k(ranked, gains, k) agg["MAP"] += average_precision(ranked, relset) agg["MRR"] += reciprocal_rank(ranked, relset) N += 1 for m in agg: agg[m] /= max(N,1) return agg # ---------- Formatting ---------- def model_family(model_id: str) -> str: mid = model_id.lower() if "e5" in mid and "multilingual" not in mid: return "e5" # English E5 if "bge" in mid: return "bge" # BGE English return "plain" # MPNet, GTE, MiniLM, multi-qa-* def format_passages(texts: List[str], family: str) -> List[str]: if family == "e5": return [f"passage: {t.strip()}" for t in texts] if family == "bge": return [f"Represent this document for retrieval: {t.strip()}" for t in texts] return [t if isinstance(t, str) else str(t) for t in texts] def format_queries(texts: List[str], family: str) -> List[str]: if family == "e5": return [f"query: {t.strip()}" for t in texts] if family == "bge": return [f"Represent this query for retrieval: {t.strip()}" for t in texts] return [t.strip() for t in texts] # ---------- FAISS ---------- def build_index(embs: np.ndarray) -> faiss.IndexFlatIP: X = embs.astype("float32") faiss.normalize_L2(X) idx = faiss.IndexFlatIP(X.shape[1]) idx.add(X) return idx # ---------- Helpers ---------- def parse_model_spec(spec: str): """ Allow 'model_id' or 'model_id::/path/to/index.faiss' Returns (model_id, faiss_override_path or None) """ if "::" in spec: m, p = spec.split("::", 1) return m.strip(), p.strip() return spec.strip(), None def load_st_model(model_id: str): from sentence_transformers import SentenceTransformer try: return SentenceTransformer(model_id) except Exception as e: print(f"[WARN] Could not load {model_id}: {e}", file=sys.stderr) return None # ---------- Main ---------- def main(): ap = argparse.ArgumentParser() ap.add_argument("--chunks", default='./assets/lecture_chunks.json', help="lecture_chunks.json") ap.add_argument("--qrels", default='./assets/qrels_recording.jsonl', help="qrels_recording.jsonl") ap.add_argument("--out_dir", default='./') ap.add_argument("--models", nargs="+", required=True, help="List like: 'sentence-transformers/multi-qa-mpnet-base-dot-v1::/path/embeddings_enhanced.faiss' 'BAAI/bge-base-en-v1.5' ...") ap.add_argument("--k_list", nargs="+", type=int, default=[1,3,5,10,20]) ap.add_argument("--topk", type=int, default=20) ap.add_argument("--batch", type=int, default=128) args = ap.parse_args() ensure_dir(args.out_dir) chunks = load_chunks(args.chunks) queries, qrels = load_qrels(args.qrels) qids = list(qrels.keys()) # Prepare raw query texts once q_texts_raw = [queries.get(qid, "") for qid in qids] results = [] per_model_runs = {} for spec in args.models: model_id, faiss_override = parse_model_spec(spec) fam = model_family(model_id) print(f"\n===== {model_id} (family={fam}) =====") mslug = model_id.replace("/", "_").replace(":", "_") mdir = os.path.join(args.out_dir, mslug) ensure_dir(mdir) emb_path = os.path.join(mdir, "doc_embs.npy") idx_path = os.path.join(mdir, "faiss.index") # Encoder (needed for queries; for docs only if we rebuild) enc = load_st_model(model_id) if enc is None: continue # ----- Build or load index ----- index = None if faiss_override: if not os.path.exists(faiss_override): print(f"[WARN] FAISS override not found: {faiss_override}") else: index = faiss.read_index(faiss_override) if index.ntotal != len(chunks): print(f"[WARN] Override index size {index.ntotal} != chunks {len(chunks)}; ignoring override.") index = None if index is None: # use cached per-model artifacts if present if os.path.exists(emb_path) and os.path.exists(idx_path): try: doc_embs = np.load(emb_path) index = faiss.read_index(idx_path) if index.ntotal != len(chunks): index = None except Exception: index = None if index is None: corpus_texts = [c["text"] for c in chunks] corpus_fmt = format_passages(corpus_texts, fam) t0 = time.time() doc_embs = enc.encode(corpus_fmt, batch_size=args.batch, convert_to_numpy=True, normalize_embeddings=True).astype("float32") t1 = time.time() - t0 print(f"[TIME] Encoded {len(chunks)} chunks in {t1:.1f}s") index = build_index(doc_embs.copy()) faiss.write_index(index, idx_path) np.save(emb_path, doc_embs) # ----- Encode queries ----- q_fmt = format_queries(q_texts_raw, fam) t0 = time.time() q_embs = enc.encode(q_fmt, batch_size=args.batch, convert_to_numpy=True, normalize_embeddings=True).astype("float32") t1 = time.time() - t0 print(f"[TIME] Encoded {len(q_embs)} queries in {t1:.1f}s") # ----- Search ----- sims, ids = index.search(q_embs, args.topk) # Build run keyed by (file, chunk_index) run = {} for i, qid in enumerate(qids): rows = ids[i].tolist() run[qid] = [(chunks[r]["file"], int(chunks[r]["chunk_index"])) for r in rows] per_model_runs[model_id] = run # ----- Evaluate ----- k_vals = tuple(sorted(set(args.k_list))) metrics = evaluate_run(run, qrels, k_vals=k_vals) row = {"model": model_id} | {k: round(float(v), 6) for k, v in metrics.items()} results.append(row) print("[RESULT]", row) # ----- Leaderboard ----- if results: results_sorted = sorted(results, key=lambda r: (r.get("nDCG@10",0.0), r.get("R@10",0.0)), reverse=True) csv_path = os.path.join(args.out_dir, "embedding_ablation_leaderboard.csv") with open(csv_path, "w", newline="", encoding="utf-8") as f: cols = ["model"] + [k for k in results_sorted[0].keys() if k != "model"] w = csv.DictWriter(f, fieldnames=cols); w.writeheader() for r in results_sorted: w.writerow(r) print(f"\n[OK] Wrote leaderboard: {csv_path}") # Plots (top-5 for clarity) try: import matplotlib.pyplot as plt k_vals = sorted(set(args.k_list)) topN = results_sorted[:5] plt.figure(figsize=(7,5)) for r in topN: m = r["model"] M = evaluate_run(per_model_runs[m], qrels, k_vals=tuple(k_vals)) ys = [M[f"R@{k}"] for k in k_vals] plt.plot(k_vals, ys, marker="o", label=m) plt.xlabel("k"); plt.ylabel("Recall"); plt.title("RecordingRAG: Recall@k across models") plt.legend(); plt.tight_layout() plt.savefig(os.path.join(args.out_dir, "recall_at_k_models.png"), dpi=150, bbox_inches="tight"); plt.close() plt.figure(figsize=(7,5)) for r in topN: m = r["model"] M = evaluate_run(per_model_runs[m], qrels, k_vals=tuple(k_vals)) ys = [M[f"nDCG@{k}"] for k in k_vals] plt.plot(k_vals, ys, marker="o", label=m) plt.xlabel("k"); plt.ylabel("nDCG"); plt.title("RecordingRAG: nDCG@k across models") plt.legend(); plt.tight_layout() plt.savefig(os.path.join(args.out_dir, "ndcg_at_k_models.png"), dpi=150, bbox_inches="tight"); plt.close() print(f"[OK] Saved plots to {args.out_dir}") except Exception as e: print(f"[WARN] Plotting failed: {e}") else: print("[ERROR] No models ran successfully.") if __name__ == "__main__": main()