# rag_ablate_hf_bigmodels.py

import os, json, math, time, csv, sys, argparse
from typing import List, Dict, Tuple, Optional

import numpy as np
import faiss
import torch

# transformers / sentence-transformers are imported lazily inside each backend

# ----------------------------
# IO
# ----------------------------
def load_chunks(p):
    with open(p, "r", encoding="utf-8") as f:
        chunks = json.load(f)
    for c in chunks:
        if "text" not in c or "file" not in c or "chunk_index" not in c:
            raise ValueError("Each chunk must have 'file', 'chunk_index', 'text'")
        c.setdefault("text", "")
        c["chunk_index"] = int(c["chunk_index"])
    return chunks

def load_qrels(p):
    qrels, queries = {}, {}
    with open(p, "r", encoding="utf-8") as f:
        for line in f:
            r = json.loads(line)
            qid = r["qid"]
            queries[qid] = r.get("query", "")
            rel_list = r.get("relevant") or r.get("rels") or r.get("labels") or r.get("items")
            if rel_list is None:
                raise ValueError(f"{p}: missing 'relevant' field")
            gains = {}
            for x in rel_list:
                fpath = x.get("file") or x.get("pdf_path") or x.get("path")
                if fpath is None:
                    raise ValueError("qrels item missing file/pdf_path/path")
                if "chunk_index" in x:
                    idx = int(x["chunk_index"])
                elif "page_index" in x:
                    idx = int(x["page_index"])
                elif "slide_number" in x:
                    idx = int(x["slide_number"]) - 1  # slide numbers are 1-based
                else:
                    raise ValueError("qrels item missing chunk_index/page_index/slide_number")
                gains[(fpath, idx)] = int(x.get("rel", 1))
            qrels[qid] = gains
    return queries, qrels

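# Illustrative input shapes (hypothetical values), inferred from load_chunks/load_qrels above:
#
#   lecture_chunks.json  — a JSON list of chunk records:
#     [{"file": "lectures/week1.pdf", "chunk_index": 0, "text": "..."}, ...]
#
#   qrels_recording.jsonl — one JSON object per line:
#     {"qid": "q1", "query": "what is ...?",
#      "relevant": [{"file": "lectures/week1.pdf", "chunk_index": 3, "rel": 2}]}
#   Relevant items may instead carry "page_index" (0-based) or "slide_number" (1-based).
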
def ensure_dir(p): os.makedirs(p, exist_ok=True)

# ----------------------------
# Metrics
# ----------------------------
def precision_at_k(ranked, relset, k):
    k = min(k, len(ranked))
    return (sum(1 for d in ranked[:k] if d in relset)/k) if k > 0 else 0.0

def recall_at_k(ranked, relset, k):
    return sum(1 for d in ranked[:k] if d in relset)/max(1, len(relset))

def average_precision(ranked, relset):
    if not relset: return 0.0
    ap, hits = 0.0, 0
    for i, d in enumerate(ranked, 1):
        if d in relset: hits += 1; ap += hits/i
    return ap/max(1, len(relset))

def reciprocal_rank(ranked, relset):
    for i, d in enumerate(ranked, 1):
        if d in relset: return 1.0/i
    return 0.0

def dcg_at_k(ranked, gains, k):
    dcg = 0.0
    for i, d in enumerate(ranked[:k], 1):
        g = gains.get(d, 0)
        if g > 0: dcg += (2**g - 1)/math.log2(i + 1)
    return dcg

def ndcg_at_k(ranked, gains, k):
    dcg = dcg_at_k(ranked, gains, k)
    ideal = sorted(gains.values(), reverse=True)
    idcg = 0.0
    for i, g in enumerate(ideal[:k], 1):
        idcg += (2**g - 1)/math.log2(i + 1)
    return dcg/idcg if idcg > 0 else 0.0

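# Toy sanity check for the metrics above (hypothetical doc ids, not project data):
#   ranked = ["a", "b", "c", "d"], gains = {"a": 1, "c": 2}, relset = {"a", "c"}
#   P@2 = 0.5, R@2 = 0.5, AP = (1/1 + 2/3)/2 ≈ 0.833, RR = 1.0
#   DCG@3 = 1/log2(2) + 3/log2(4) = 2.5, IDCG@3 = 3 + 1/log2(3) ≈ 3.631, nDCG@3 ≈ 0.689
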
def evaluate_run(run: Dict[str, List[Tuple[str, int]]], qrels: Dict[str, Dict[Tuple[str, int], int]], k_vals=(1, 3, 5, 10)):
    agg = {f"P@{k}": 0.0 for k in k_vals} | {f"R@{k}": 0.0 for k in k_vals} | {f"nDCG@{k}": 0.0 for k in k_vals}
    agg["MAP"] = 0.0; agg["MRR"] = 0.0
    N = 0
    for qid, gains in qrels.items():
        ranked = run.get(qid, [])
        relset = {d for d, g in gains.items() if g > 0}
        for k in k_vals:
            agg[f"P@{k}"] += precision_at_k(ranked, relset, k)
            agg[f"R@{k}"] += recall_at_k(ranked, relset, k)
            agg[f"nDCG@{k}"] += ndcg_at_k(ranked, gains, k)
        agg["MAP"] += average_precision(ranked, relset)
        agg["MRR"] += reciprocal_rank(ranked, relset)
        N += 1
    for m in agg: agg[m] /= max(N, 1)
    return agg

# ----------------------------
# FAISS
# ----------------------------
def build_index(embs: np.ndarray) -> faiss.IndexFlatIP:
    X = embs.astype("float32")
    faiss.normalize_L2(X)                # unit-normalize so inner product == cosine similarity
    idx = faiss.IndexFlatIP(X.shape[1])  # exact (brute-force) inner-product index
    idx.add(X)
    return idx

# ----------------------------
# Backends
# ----------------------------
def _mean_pool(last_hidden_state, attention_mask):
    # last_hidden_state: [B, T, D], attention_mask: [B, T]
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # [B, T, 1]
    summed = (last_hidden_state * mask).sum(dim=1)                   # [B, D]
    counts = mask.sum(dim=1).clamp(min=1e-9)                         # [B, 1]
    return summed / counts

class STBackend:
    def __init__(self, model_id, device, fp16=False):
        from sentence_transformers import SentenceTransformer
        self.model = SentenceTransformer(model_id, device=device)
        self.fp16 = fp16
        self.name = model_id

    @torch.inference_mode()
    def encode(self, texts: List[str], batch: int) -> np.ndarray:
        return self.model.encode(texts, batch_size=batch, convert_to_numpy=True, normalize_embeddings=True).astype("float32")

class HFEncoderBackend:
    def __init__(self, model_id, device, fp16=False, cache_dir=None):
        from transformers import AutoModel, AutoTokenizer
        self.tok = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True, cache_dir=cache_dir)
        self.model = AutoModel.from_pretrained(model_id, trust_remote_code=True, cache_dir=cache_dir)
        self.model.to(device).eval()
        self.device = device
        self.fp16 = fp16
        self.name = model_id

    @torch.inference_mode()
    def encode(self, texts: List[str], batch: int, max_length: int) -> np.ndarray:
        outs = []
        for i in range(0, len(texts), batch):
            t = texts[i:i+batch]
            enc = self.tok(t, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(self.device)
            with torch.autocast(device_type=self.device.split(":")[0], enabled=self.fp16):
                out = self.model(**enc)
            last = out.last_hidden_state                   # [B, T, D]
            pooled = _mean_pool(last, enc.attention_mask)  # [B, D]
            pooled = torch.nn.functional.normalize(pooled, dim=-1)
            outs.append(pooled.detach().cpu().numpy().astype("float32"))
        return np.concatenate(outs, axis=0)

# put near other backends
class NVEmbedBackend:
    """
    Uses the model's custom .encode() (trust_remote_code=True) and applies the
    recommended instruction prefix for queries.
    """
    def __init__(self, model_id, device, fp16=False, cache_dir=None):
        from transformers import AutoTokenizer, AutoModel
        self.tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, cache_dir=cache_dir)
        self.model = AutoModel.from_pretrained(model_id, trust_remote_code=True, cache_dir=cache_dir)
        self.model.to(device).eval()
        self.device = device
        self.fp16 = fp16
        self.name = model_id

    @torch.inference_mode()
    def encode_docs(self, texts, batch: int, max_length: int) -> np.ndarray:
        # No instruction for passages
        import torch.nn.functional as F
        outs = []
        for i in range(0, len(texts), batch):
            batch_texts = texts[i:i+batch]
            # model.encode returns torch.Tensor [B, D]
            emb = self.model.encode(batch_texts, instruction="", max_length=max_length)
            emb = F.normalize(emb, p=2, dim=1)
            outs.append(emb.detach().cpu().numpy().astype("float32"))
        return np.concatenate(outs, axis=0)

    @torch.inference_mode()
    def encode_queries(self, texts, batch: int, max_length: int, instruction: str) -> np.ndarray:
        import torch.nn.functional as F
        outs = []
        for i in range(0, len(texts), batch):
            batch_texts = [t.strip() for t in texts[i:i+batch]]
            emb = self.model.encode(batch_texts, instruction=instruction, max_length=max_length)
            emb = F.normalize(emb, p=2, dim=1)
            outs.append(emb.detach().cpu().numpy().astype("float32"))
        return np.concatenate(outs, axis=0)

class HFDecoderBackend:
    # e.g., gemma-3-1b-it; we pool last hidden states from a causal LM
    def __init__(self, model_id, device, fp16=False, cache_dir=None):
        from transformers import AutoTokenizer, AutoModelForCausalLM
        self.tok = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True, cache_dir=cache_dir)
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token  # some causal-LM tokenizers ship without a pad token
        self.model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, cache_dir=cache_dir)
        self.model.to(device).eval()
        self.device = device
        self.fp16 = fp16
        self.name = model_id

    @torch.inference_mode()
    def encode(self, texts: List[str], batch: int, max_length: int) -> np.ndarray:
        outs = []
        for i in range(0, len(texts), batch):
            t = texts[i:i+batch]
            enc = self.tok(t, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to(self.device)
            with torch.autocast(device_type=self.device.split(":")[0], enabled=self.fp16):
                out = self.model(**enc, output_hidden_states=True)
            last = out.hidden_states[-1]                   # [B, T, D]
            pooled = _mean_pool(last, enc.attention_mask)  # [B, D]
            pooled = torch.nn.functional.normalize(pooled, dim=-1)
            outs.append(pooled.detach().cpu().numpy().astype("float32"))
        return np.concatenate(outs, axis=0)

# ----------------------------
# Model registry / family detection
# ----------------------------
def pick_backend(model_id: str, device: str, fp16: bool, cache_dir: Optional[str]):
    mid = model_id.lower()
    # sentence-transformers models (rare in your list, but just in case)
    if "sentence-transformers/" in mid:
        return STBackend(model_id, device, fp16)
    # reranker (skip)
    if "reranker" in mid:
        return None  # signal to skip
    if "google/embeddinggemma-300m" in mid:
        return STBackend(model_id, device, fp16)
    # decoder-style LLM embeddings
    if "gemma" in mid:
        return HFDecoderBackend(model_id, device, fp16, cache_dir)
    # default: encoder-like HF model
    return HFEncoderBackend(model_id, device, fp16, cache_dir)

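# NOTE: NVEmbedBackend above is not routed by pick_backend yet. A possible wiring (sketch,
# untested; the main loop would also have to call encode_docs()/encode_queries() with
# args.nv_query_instr instead of the uniform .encode()):
#
#   if "nv-embed" in mid:
#       return NVEmbedBackend(model_id, device, fp16, cache_dir)
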
def add_formatting(texts: List[str], family: str, is_query: bool) -> List[str]:
    # Keep neutral by default. If you later learn family-specific instructions, add here.
    # Examples if desired:
    # if family == "qwen": prefix = "query: " if is_query else "passage: "
    # if family == "nv":   prefix = "search_query: " if is_query else "search_document: "
    prefix = ""
    return [f"{prefix}{t.strip()}" for t in texts]

def family_of(model_id: str) -> str:
    mid = model_id.lower()
    if "qwen3-embedding" in mid or "qwen" in mid: return "qwen"
    if "linq" in mid: return "linq"
    if "nv-embed" in mid or "nvidia/" in mid: return "nv"
    if "gemma" in mid: return "gemma"
    return "plain"

# ----------------------------
# Main
# ----------------------------
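# Example invocation (sketch; flags are the argparse options below, --out_dir is required):
#   python rag_ablate_hf_bigmodels.py --out_dir ./runs/embed_ablation --fp16 \
#       --chunks ./assets/lecture_chunks.json --qrels ./assets/qrels_recording.jsonl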
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--chunks", default='./assets/lecture_chunks.json', help="lecture_chunks.json")
    ap.add_argument("--qrels", default='./assets/qrels_recording.jsonl', help="qrels_recording.jsonl")
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--models", nargs="+", default=[
        "google/embeddinggemma-300m",
        "google/gemma-3-1b-it",
        "Qwen/Qwen3-Embedding-4B",
        "Linq-AI-Research/Linq-Embed-Mistral",
        # "nvidia/NV-Embed-v2",
    ])
    ap.add_argument("--nv_query_instr", default="Instruct: Given a question, retrieve passages that answer the question\nQuery: ", help="Instruction prefix used for NV-Embed-v2 queries.")
    ap.add_argument("--topk", type=int, default=20)
    ap.add_argument("--k_list", nargs="+", type=int, default=[1, 3, 5, 10, 20])
    ap.add_argument("--batch", type=int, default=16)
    ap.add_argument("--max_length", type=int, default=512)
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--fp16", action="store_true", help="Enable autocast fp16/bf16 on GPU")
    ap.add_argument("--cache_dir", default=None, help="HF cache dir")
    args = ap.parse_args()

    ensure_dir(args.out_dir)
    chunks = load_chunks(args.chunks)
    queries, qrels = load_qrels(args.qrels)
    qids = list(qrels.keys())
    q_texts_raw = [queries.get(qid, "") for qid in qids]
    corpus_raw = [c["text"] for c in chunks]

    results = []
    per_model_runs = {}

    for model_id in args.models:
        fam = family_of(model_id)
        be = pick_backend(model_id, args.device, args.fp16, args.cache_dir)

        if be is None:
            print(f"\n===== {model_id} =====")
            print("[SKIP] Detected reranker; skipping in embedding ablation.")
            continue

        print(f"\n===== {model_id} (family={fam}) =====")
        mslug = model_id.replace("/", "_").replace(":", "_")
        mdir = os.path.join(args.out_dir, mslug)
        ensure_dir(mdir)

        # Format texts (neutral by default; edit add_formatting() to try family prompts)
        corpus_fmt = add_formatting(corpus_raw, fam, is_query=False)
        q_fmt = add_formatting(q_texts_raw, fam, is_query=True)

        # Encode corpus
        t0 = time.time()
        doc_embs = be.encode(corpus_fmt, batch=args.batch, max_length=args.max_length) \
            if not isinstance(be, STBackend) else be.encode(corpus_fmt, batch=args.batch)
        t1 = time.time() - t0
        print(f"[TIME] Encoded {len(corpus_fmt)} chunks in {t1/60:.2f} min")

        # FAISS index
        index = build_index(doc_embs.copy())

        # Encode queries
        t0 = time.time()
        q_embs = be.encode(q_fmt, batch=args.batch, max_length=args.max_length) \
            if not isinstance(be, STBackend) else be.encode(q_fmt, batch=args.batch)
        t1 = time.time() - t0
        print(f"[TIME] Encoded {len(q_fmt)} queries in {t1:.1f}s")

        # Search
        sims, ids = index.search(q_embs, args.topk)

        # Build run keyed by (file, chunk_index)
        run = {}
        for i, qid in enumerate(qids):
            rows = ids[i].tolist()
            # FAISS pads with -1 when topk exceeds the number of indexed chunks
            run[qid] = [(chunks[r]["file"], int(chunks[r]["chunk_index"])) for r in rows if r != -1]

        per_model_runs[model_id] = run

        # Evaluate
        k_vals = tuple(sorted(set(args.k_list)))
        metrics = evaluate_run(run, qrels, k_vals=k_vals)
        row = {"model": model_id} | {k: round(float(v), 6) for k, v in metrics.items()}
        results.append(row)
        print("[RESULT]", row)

        # Save per-model artifacts (optional)
        try:
            np.save(os.path.join(mdir, "doc_embs.npy"), doc_embs.astype("float32"))
            faiss.write_index(index, os.path.join(mdir, "faiss.index"))
        except Exception as e:
            print(f"[WARN] Could not save embeddings/index for {model_id}: {e}")

    # Leaderboard
    if results:
        results_sorted = sorted(results, key=lambda r: (r.get("nDCG@10", 0.0), r.get("R@10", 0.0)), reverse=True)
        csv_path = os.path.join(args.out_dir, "embedding_ablation_leaderboard.csv")
        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            cols = ["model"] + [k for k in results_sorted[0].keys() if k != "model"]
            w = csv.DictWriter(f, fieldnames=cols); w.writeheader()
            for r in results_sorted: w.writerow(r)
        print(f"\n[OK] Wrote leaderboard: {csv_path}")

        # Plots
        try:
            import matplotlib.pyplot as plt
            k_vals = sorted(set(args.k_list))
            topN = results_sorted[:5]

            plt.figure(figsize=(7, 5))
            for r in topN:
                m = r["model"]; M = evaluate_run(per_model_runs[m], qrels, k_vals=tuple(k_vals))
                ys = [M[f"R@{k}"] for k in k_vals]
                plt.plot(k_vals, ys, marker="o", label=m)
            plt.xlabel("k"); plt.ylabel("Recall"); plt.title("RecordingRAG: Recall@k across models")
            plt.legend(); plt.tight_layout()
            plt.savefig(os.path.join(args.out_dir, "recall_at_k_models.png"), dpi=150, bbox_inches="tight"); plt.close()

            plt.figure(figsize=(7, 5))
            for r in topN:
                m = r["model"]; M = evaluate_run(per_model_runs[m], qrels, k_vals=tuple(k_vals))
                ys = [M[f"nDCG@{k}"] for k in k_vals]
                plt.plot(k_vals, ys, marker="o", label=m)
            plt.xlabel("k"); plt.ylabel("nDCG"); plt.title("RecordingRAG: nDCG@k across models")
            plt.legend(); plt.tight_layout()
            plt.savefig(os.path.join(args.out_dir, "ndcg_at_k_models.png"), dpi=150, bbox_inches="tight"); plt.close()
            print(f"[OK] Saved plots to {args.out_dir}")
        except Exception as e:
            print(f"[WARN] Plotting failed: {e}")
    else:
        print("[ERROR] No models produced results.")

if __name__ == "__main__":
    main()
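
# Outputs written under --out_dir (summary of what main() above produces):
#   <model_slug>/doc_embs.npy, <model_slug>/faiss.index   — per-model artifacts (best effort)
#   embedding_ablation_leaderboard.csv                     — one row per model: P@k/R@k/nDCG@k, MAP, MRR
#   recall_at_k_models.png, ndcg_at_k_models.png           — curves for the top-5 models by nDCG@10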