# rag_ablate_hf_bigmodels.py import os, json, math, time, csv, sys, argparse from typing import List, Dict, Tuple, Optional import numpy as np import faiss import torch # Optional torch/transformers are imported lazily in backends # ---------------------------- # IO # ---------------------------- def load_chunks(p): with open(p, "r", encoding="utf-8") as f: chunks = json.load(f) for c in chunks: if "text" not in c or "file" not in c or "chunk_index" not in c: raise ValueError("Each chunk must have 'file', 'chunk_index', 'text'") c.setdefault("text", "") c["chunk_index"] = int(c["chunk_index"]) return chunks def load_qrels(p): qrels, queries = {}, {} with open(p, "r", encoding="utf-8") as f: for line in f: r = json.loads(line) qid = r["qid"] queries[qid] = r.get("query", "") rel_list = r.get("relevant") or r.get("rels") or r.get("labels") or r.get("items") if rel_list is None: raise ValueError(f"{p}: missing 'relevant' field") gains = {} for x in rel_list: fpath = x.get("file") or x.get("pdf_path") or x.get("path") if fpath is None: raise ValueError("qrels item missing file/pdf_path/path") if "chunk_index" in x: idx = int(x["chunk_index"]) elif "page_index" in x: idx = int(x["page_index"]) elif "slide_number" in x: idx = int(x["slide_number"]) - 1 else: raise ValueError("qrels item missing chunk_index/page_index/slide_number") gains[(fpath, idx)] = int(x.get("rel", 1)) qrels[qid] = gains return queries, qrels def ensure_dir(p): os.makedirs(p, exist_ok=True) # ---------------------------- # Metrics # ---------------------------- def precision_at_k(ranked, relset, k): k = min(k, len(ranked)); return (sum(1 for d in ranked[:k] if d in relset)/k) if k>0 else 0.0 def recall_at_k(ranked, relset, k): return (sum(1 for d in ranked[:k] if d in relset)/max(1,len(relset))) def average_precision(ranked, relset): if not relset: return 0.0 ap, hits = 0.0, 0 for i,d in enumerate(ranked,1): if d in relset: hits+=1; ap+=hits/i return ap/max(1,len(relset)) def reciprocal_rank(ranked, relset): for i,d in enumerate(ranked,1): if d in relset: return 1.0/i return 0.0 def dcg_at_k(ranked, gains, k): dcg=0.0 for i,d in enumerate(ranked[:k],1): g=gains.get(d,0) if g>0: dcg+=(2**g-1)/math.log2(i+1) return dcg def ndcg_at_k(ranked, gains, k): dcg = dcg_at_k(ranked,gains,k) ideal = sorted(gains.values(), reverse=True) idcg=0.0 for i,g in enumerate(ideal[:k],1): idcg+=(2**g-1)/math.log2(i+1) return dcg/idcg if idcg>0 else 0.0 def evaluate_run(run: Dict[str, List[Tuple[str,int]]], qrels: Dict[str, Dict[Tuple[str,int], int]], k_vals=(1,3,5,10)): agg = {f"P@{k}":0.0 for k in k_vals} | {f"R@{k}":0.0 for k in k_vals} | {f"nDCG@{k}":0.0 for k in k_vals} agg["MAP"] = 0.0; agg["MRR"] = 0.0 N = 0 for qid, gains in qrels.items(): ranked = run.get(qid, []) relset = {d for d,g in gains.items() if g > 0} for k in k_vals: agg[f"P@{k}"] += precision_at_k(ranked, relset, k) agg[f"R@{k}"] += recall_at_k(ranked, relset, k) agg[f"nDCG@{k}"] += ndcg_at_k(ranked, gains, k) agg["MAP"] += average_precision(ranked, relset) agg["MRR"] += reciprocal_rank(ranked, relset) N += 1 for m in agg: agg[m] /= max(N,1) return agg # ---------------------------- # FAISS # ---------------------------- def build_index(embs: np.ndarray) -> faiss.IndexFlatIP: X = embs.astype("float32") faiss.normalize_L2(X) idx = faiss.IndexFlatIP(X.shape[1]) idx.add(X) return idx # ---------------------------- # Backends # ---------------------------- def _mean_pool(last_hidden_state, attention_mask): # last_hidden_state: [B, T, D], attention_mask: [B, T] mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype) # [B, T, 1] summed = (last_hidden_state * mask).sum(dim=1) # [B, D] counts = mask.sum(dim=1).clamp(min=1e-9) # [B, 1] return summed / counts class STBackend: def __init__(self, model_id, device, fp16=False): from sentence_transformers import SentenceTransformer self.model = SentenceTransformer(model_id, device=device) self.fp16 = fp16 self.name = model_id @torch.inference_mode() def encode(self, texts: List[str], batch: int) -> np.ndarray: return self.model.encode(texts, batch_size=batch, convert_to_numpy=True, normalize_embeddings=True).astype("float32") class HFEncoderBackend: def __init__(self, model_id, device, fp16=False, cache_dir=None): from transformers import AutoModel, AutoTokenizer self.tok = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True, cache_dir=cache_dir) self.model = AutoModel.from_pretrained(model_id, trust_remote_code=True, cache_dir=cache_dir) self.model.to(device).eval() self.device = device self.fp16 = fp16 self.name = model_id @torch.inference_mode() def encode(self, texts: List[str], batch: int, max_length: int) -> np.ndarray: outs = [] for i in range(0, len(texts), batch): t = texts[i:i+batch] enc = self.tok(t, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(self.device) with torch.autocast(self.device, enabled=self.fp16): out = self.model(**enc) last = out.last_hidden_state # [B,T,D] pooled = _mean_pool(last, enc.attention_mask) # [B,D] pooled = torch.nn.functional.normalize(pooled, dim=-1) outs.append(pooled.detach().cpu().numpy().astype("float32")) return np.concatenate(outs, axis=0) # put near other backends class NVEmbedBackend: """ Uses the model's custom .encode() (trust_remote_code=True) and applies the recommended instruction prefix for queries. """ def __init__(self, model_id, device, fp16=False, cache_dir=None): from transformers import AutoTokenizer, AutoModel import torch self.tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, cache_dir=cache_dir) self.model = AutoModel.from_pretrained(model_id, trust_remote_code=True, cache_dir=cache_dir) self.model.to(device).eval() self.device = device self.fp16 = fp16 self.name = model_id @torch.inference_mode() def encode_docs(self, texts, batch: int, max_length: int) -> np.ndarray: # No instruction for passages import torch, torch.nn.functional as F outs = [] for i in range(0, len(texts), batch): batch_texts = texts[i:i+batch] # model.encode returns torch.Tensor [B, D] emb = self.model.encode(batch_texts, instruction="", max_length=max_length) emb = F.normalize(emb, p=2, dim=1) outs.append(emb.detach().cpu().numpy().astype("float32")) return np.concatenate(outs, axis=0) @torch.inference_mode() def encode_queries(self, texts, batch: int, max_length: int, instruction: str) -> np.ndarray: import torch, torch.nn.functional as F outs = [] for i in range(0, len(texts), batch): batch_texts = [t.strip() for t in texts[i:i+batch]] emb = self.model.encode(batch_texts, instruction=instruction, max_length=max_length) emb = F.normalize(emb, p=2, dim=1) outs.append(emb.detach().cpu().numpy().astype("float32")) return np.concatenate(outs, axis=0) class HFDecoderBackend: # e.g., gemma-3-1b-it; we pool last hidden states from a causal LM def __init__(self, model_id, device, fp16=False, cache_dir=None): from transformers import AutoTokenizer, AutoModelForCausalLM self.tok = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True, cache_dir=cache_dir) self.model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, cache_dir=cache_dir) self.model.to(device).eval() self.device = device self.fp16 = fp16 self.name = model_id @torch.inference_mode() def encode(self, texts: List[str], batch: int, max_length: int) -> np.ndarray: outs = [] for i in range(0, len(texts), batch): t = texts[i:i+batch] enc = self.tok(t, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(self.device) with torch.autocast(self.device, enabled=self.fp16): out = self.model(**enc, output_hidden_states=True) last = out.hidden_states[-1] # [B,T,D] pooled = _mean_pool(last, enc.attention_mask) # [B,D] pooled = torch.nn.functional.normalize(pooled, dim=-1) outs.append(pooled.detach().cpu().numpy().astype("float32")) return np.concatenate(outs, axis=0) # ---------------------------- # Model registry / family detection # ---------------------------- def pick_backend(model_id: str, device: str, fp16: bool, cache_dir: Optional[str]): mid = model_id.lower() # sentence-transformers models (rare in your list, but just in case) if "sentence-transformers/" in mid: return STBackend(model_id, device, fp16) # reranker (skip) if "reranker" in mid: return None # signal to skip if "google/embeddinggemma-300m" in mid: return STBackend(model_id, device, fp16) # decoder-style LLM embeddings if "gemma" in mid: return HFDecoderBackend(model_id, device, fp16, cache_dir) # default: encoder-like HF model return HFEncoderBackend(model_id, device, fp16, cache_dir) def add_formatting(texts: List[str], family: str, is_query: bool) -> List[str]: # Keep neutral by default. If you later learn family-specific instructions, add here. # Examples if desired: # if family == "qwen": prefix = "query: " if is_query else "passage: " # if family == "nv": prefix = "search_query: " if is_query else "search_document: " prefix = "" return [f"{prefix}{t.strip()}" for t in texts] def family_of(model_id: str) -> str: mid = model_id.lower() if "qwen3-embedding" in mid or "qwen" in mid: return "qwen" if "linq" in mid: return "linq" if "nv-embed" in mid or "nvidia/" in mid: return "nv" if "gemma" in mid: return "gemma" return "plain" # ---------------------------- # Main # ---------------------------- def main(): ap = argparse.ArgumentParser() ap.add_argument("--chunks", default='./assets/lecture_chunks.json', help="lecture_chunks.json") ap.add_argument("--qrels", default='./assets/qrels_recording.jsonl', help="qrels_recording.jsonl") ap.add_argument("--out_dir", required=True) ap.add_argument("--models", nargs="+", default=[ "google/embeddinggemma-300m", "google/gemma-3-1b-it", "Qwen/Qwen3-Embedding-4B", "Linq-AI-Research/Linq-Embed-Mistral", # "nvidia/NV-Embed-v2", ]) ap.add_argument("--nv_query_instr", default="Instruct: Given a question, retrieve passages that answer the question\nQuery: ", help="Instruction prefix used for NV-Embed-v2 queries.") ap.add_argument("--topk", type=int, default=20) ap.add_argument("--k_list", nargs="+", type=int, default=[1,3,5,10,20]) ap.add_argument("--batch", type=int, default=16) ap.add_argument("--max_length", type=int, default=512) ap.add_argument("--device", default="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES", "") != "" else "cpu") ap.add_argument("--fp16", action="store_true", help="Enable autocast fp16/bf16 on GPU") ap.add_argument("--cache_dir", default=None, help="HF cache dir") args = ap.parse_args() ensure_dir(args.out_dir) chunks = load_chunks(args.chunks) queries, qrels = load_qrels(args.qrels) qids = list(qrels.keys()) q_texts_raw = [queries.get(qid, "") for qid in qids] corpus_raw = [c["text"] for c in chunks] results = [] per_model_runs = {} # lazy import torch once global torch import torch for model_id in args.models: fam = family_of(model_id) be = pick_backend(model_id, args.device, args.fp16, args.cache_dir) if be is None: print(f"\n===== {model_id} =====") print("[SKIP] Detected reranker; skipping in embedding ablation.") continue print(f"\n===== {model_id} (family={fam}) =====") mslug = model_id.replace("/", "_").replace(":", "_") mdir = os.path.join(args.out_dir, mslug) ensure_dir(mdir) # Format texts (neutral by default; edit add_formatting() to try family prompts) corpus_fmt = add_formatting(corpus_raw, fam, is_query=False) q_fmt = add_formatting(q_texts_raw, fam, is_query=True) # Encode corpus t0 = time.time() doc_embs = be.encode(corpus_fmt, batch=args.batch, max_length=args.max_length) \ if not isinstance(be, STBackend) else be.encode(corpus_fmt, batch=args.batch) t1 = time.time() - t0 print(f"[TIME] Encoded {len(corpus_fmt)} chunks in {t1/60:.2f} min") # FAISS index index = build_index(doc_embs.copy()) # Encode queries t0 = time.time() q_embs = be.encode(q_fmt, batch=args.batch, max_length=args.max_length) \ if not isinstance(be, STBackend) else be.encode(q_fmt, batch=args.batch) t1 = time.time() - t0 print(f"[TIME] Encoded {len(q_fmt)} queries in {t1:.1f}s") # Search sims, ids = index.search(q_embs, args.topk) # Build run keyed by (file, chunk_index) run = {} for i, qid in enumerate(qids): rows = ids[i].tolist() run[qid] = [(chunks[r]["file"], int(chunks[r]["chunk_index"])) for r in rows] per_model_runs[model_id] = run # Evaluate k_vals = tuple(sorted(set(args.k_list))) metrics = evaluate_run(run, qrels, k_vals=k_vals) row = {"model": model_id} | {k: round(float(v), 6) for k, v in metrics.items()} results.append(row) print("[RESULT]", row) # Save per-model artifacts (optional) try: np.save(os.path.join(mdir, "doc_embs.npy"), doc_embs.astype("float32")) faiss.write_index(index, os.path.join(mdir, "faiss.index")) except Exception as e: print(f"[WARN] Could not save embeddings/index for {model_id}: {e}") # Leaderboard if results: results_sorted = sorted(results, key=lambda r: (r.get("nDCG@10",0.0), r.get("R@10",0.0)), reverse=True) csv_path = os.path.join(args.out_dir, "embedding_ablation_leaderboard.csv") with open(csv_path, "w", newline="", encoding="utf-8") as f: cols = ["model"] + [k for k in results_sorted[0].keys() if k != "model"] w = csv.DictWriter(f, fieldnames=cols); w.writeheader() for r in results_sorted: w.writerow(r) print(f"\n[OK] Wrote leaderboard: {csv_path}") # Plots try: import matplotlib.pyplot as plt k_vals = sorted(set(args.k_list)) topN = results_sorted[:5] plt.figure(figsize=(7,5)) for r in topN: m = r["model"]; M = evaluate_run(per_model_runs[m], qrels, k_vals=tuple(k_vals)) ys = [M[f"R@{k}"] for k in k_vals] plt.plot(k_vals, ys, marker="o", label=m) plt.xlabel("k"); plt.ylabel("Recall"); plt.title("RecordingRAG: Recall@k across models") plt.legend(); plt.tight_layout() plt.savefig(os.path.join(args.out_dir, "recall_at_k_models.png"), dpi=150, bbox_inches="tight"); plt.close() plt.figure(figsize=(7,5)) for r in topN: m = r["model"]; M = evaluate_run(per_model_runs[m], qrels, k_vals=tuple(k_vals)) ys = [M[f"nDCG@{k}"] for k in k_vals] plt.plot(k_vals, ys, marker="o", label=m) plt.xlabel("k"); plt.ylabel("nDCG"); plt.title("RecordingRAG: nDCG@k across models") plt.legend(); plt.tight_layout() plt.savefig(os.path.join(args.out_dir, "ndcg_at_k_models.png"), dpi=150, bbox_inches="tight"); plt.close() print(f"[OK] Saved plots to {args.out_dir}") except Exception as e: print(f"[WARN] Plotting failed: {e}") else: print("[ERROR] No models produced results.") if __name__ == "__main__": main()