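# RecordingRAG/04_rag_evaluate.py
# Evaluates a FAISS dense retriever over lecture-recording chunks against JSONL relevance judgements (qrels).
#
# Source-viewer metadata (not code): 179 lines · 6.7 KiB · Python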

import os, json, argparse, math, csv, numpy as np, faiss
# ---------- IO ----------
def load_chunks(p):
    """Load the list of chunks that the retriever indexes.

    The file must hold a JSON array of objects, each with the keys ``file``,
    ``chunk_index`` and ``text``.

    Args:
        p: Path to the UTF-8 JSON file.

    Returns:
        The list of chunk dicts, with ``chunk_index`` coerced to ``int`` and a
        ``null`` ``text`` replaced by ``""`` so downstream string formatting
        does not fail.

    Raises:
        ValueError: If the top-level value is not a list, an entry is not an
            object, or an entry lacks a required key.
    """
    with open(p, "r", encoding="utf-8") as f:
        chunks = json.load(f)
    if not isinstance(chunks, list):
        raise ValueError("Chunks file must contain a JSON array of chunk objects")
    for i, c in enumerate(chunks):
        if not isinstance(c, dict):
            raise ValueError(
                f"Each chunk must have 'file', 'chunk_index', 'text' (chunk #{i} is not an object)"
            )
        missing = [key for key in ("file", "chunk_index", "text") if key not in c]
        if missing:
            raise ValueError(
                f"Each chunk must have 'file', 'chunk_index', 'text' (chunk #{i} is missing {missing})"
            )
        # The key is guaranteed present above, so only an explicit null needs a default.
        if c["text"] is None:
            c["text"] = ""
        c["chunk_index"] = int(c["chunk_index"])
    return chunks
def load_qrels(p):
    """Load queries and graded relevance judgements from a JSONL file.

    Each non-blank line is a JSON object with a ``qid``, an optional ``query``
    string, and a ``relevant`` list of ``{"file", "chunk_index", "rel"}``
    objects (``rel`` defaults to 1). Blank lines are skipped.

    Args:
        p: Path to the UTF-8 JSONL file.

    Returns:
        ``(queries, qrels)`` where ``queries`` maps qid -> query text (``""``
        when absent or null) and ``qrels`` maps qid ->
        ``{(file, chunk_index): gain}``.
    """
    qrels, queries = {}, {}
    with open(p, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate empty lines (e.g. extra trailing newlines) in hand-edited JSONL.
                continue
            r = json.loads(line)
            qid = r["qid"]
            queries[qid] = r.get("query") or ""
            qrels[qid] = {
                (x["file"], int(x["chunk_index"])): int(x.get("rel", 1))
                for x in r["relevant"]
            }
    return queries, qrels
def ensure_dir(p):
    """Create directory *p*, including missing parents; do nothing if it already exists."""
    os.makedirs(name=p, exist_ok=True)
# ---------- Metrics ----------
def precision_at_k(ranked, relset, k):
    """Return precision at cutoff k: (# relevant items in ranked[:k]) / k.

    The denominator is always k, even when fewer than k items were retrieved:
    missing slots count as non-relevant (the standard trec_eval convention).
    Dividing by the number retrieved instead would inflate P@k whenever the
    ranking is shorter than k.

    Args:
        ranked: Ranked sequence of document keys, best first.
        relset: Set of relevant document keys.
        k: Cutoff.

    Returns:
        Precision in [0, 1]; 0.0 when k <= 0.
    """
    if k <= 0:
        return 0.0
    hits = sum(1 for d in ranked[:k] if d in relset)
    return hits / k
def recall_at_k(ranked, relset, k):
    """Return the fraction of *relset* found among the first k ranked items.

    The denominator is clamped to at least 1, so an empty relevant set
    yields 0.0 instead of raising ZeroDivisionError.
    """
    found = len([doc for doc in ranked[:k] if doc in relset])
    denom = len(relset) or 1
    return found / denom
def average_precision(ranked, relset):
    """Return the average precision of *ranked* against a binary relevant set.

    Precision@i is summed at every rank i that holds a relevant item, then
    divided by |relset|, so relevant items that were never retrieved count
    as zero. Returns 0.0 when *relset* is empty.
    """
    if not relset:
        return 0.0
    total = 0.0
    seen = 0
    for rank, doc in enumerate(ranked, start=1):
        if doc not in relset:
            continue
        seen += 1
        total += seen / rank
    return total / len(relset)
def reciprocal_rank(ranked, relset):
    """Return 1/rank of the first relevant item in *ranked* (1-based), or 0.0 if none."""
    first = next(
        (pos for pos, doc in enumerate(ranked, start=1) if doc in relset),
        None,
    )
    return 0.0 if first is None else 1.0 / first
def dcg_at_k(ranked, gains, k):
    """Return the discounted cumulative gain of the first k ranked items.

    Uses exponential gain (2**g - 1) with a log2(rank + 1) discount. Items
    missing from *gains*, or with a non-positive gain, contribute nothing.
    """
    dcg = 0.0
    for i, d in enumerate(ranked[:k], 1):
        g = gains.get(d, 0)
        if g > 0:
            dcg += (2**g - 1) / math.log2(i + 1)
    return dcg


def ndcg_at_k(ranked, gains, k):
    """Return normalised DCG at cutoff k: DCG / ideal DCG.

    The ideal ranking is built only from positive gains. This matches
    ``dcg_at_k``, which ignores non-positive gains. If negative judgements
    were included, they would shrink the ideal DCG and let scores exceed 1.

    Returns:
        A value in [0, 1]; 0.0 when no judged document has a positive gain.
    """
    dcg = dcg_at_k(ranked, gains, k)
    ideal = sorted((g for g in gains.values() if g > 0), reverse=True)
    idcg = sum((2**g - 1) / math.log2(i + 1) for i, g in enumerate(ideal[:k], 1))
    return dcg / idcg if idcg > 0 else 0.0
def evaluate(run, qrels, k_vals):
    """Average retrieval metrics over every query that has relevance judgements.

    Args:
        run: Mapping qid -> ranked list of document keys. Queries missing
            from *run* are scored as an empty ranking.
        qrels: Mapping qid -> {doc_key: gain}. For the binary metrics
            (P, R, MAP, MRR), a document is relevant when its gain is > 0.
        k_vals: Cutoffs used for P@k, R@k and nDCG@k.

    Returns:
        Dict of metric name -> mean over the queries in *qrels*, with keys
        "P@k" for each k, then "R@k", then "nDCG@k", then "MAP" and "MRR".
        All values are 0.0 when *qrels* is empty.
    """
    metrics = {}
    for prefix in ("P@", "R@", "nDCG@"):
        for k in k_vals:
            metrics[f"{prefix}{k}"] = 0.0
    metrics["MAP"] = 0.0
    metrics["MRR"] = 0.0

    n_queries = 0
    for qid, gains in qrels.items():
        ranked = run.get(qid, [])
        relevant = {doc for doc, gain in gains.items() if gain > 0}
        for k in k_vals:
            metrics[f"P@{k}"] += precision_at_k(ranked, relevant, k)
            metrics[f"R@{k}"] += recall_at_k(ranked, relevant, k)
            metrics[f"nDCG@{k}"] += ndcg_at_k(ranked, gains, k)
        metrics["MAP"] += average_precision(ranked, relevant)
        metrics["MRR"] += reciprocal_rank(ranked, relevant)
        n_queries += 1

    # Clamp to 1 so an empty qrels mapping returns zeros instead of dividing by zero.
    divisor = n_queries or 1
    return {name: total / divisor for name, total in metrics.items()}
# ---------- Formatting ----------
def detect_family(model_id: str) -> str:
    """Guess the prompt-format family ("bge", "e5" or "none") from a model id.

    Matching is a case-insensitive substring test, and "bge" is checked first.
    Multilingual E5 checkpoints (e.g. ``intfloat/multilingual-e5-base``) are
    classified as "e5" too. Their model cards require the same
    "query: " / "passage: " prefixes as the English E5 models, so excluding
    them would embed their inputs without the required prefixes.
    """
    mid = model_id.lower()
    if "bge" in mid:
        return "bge"
    if "e5" in mid:
        return "e5"
    return "none"
def format_passages(texts, fmt: str):
    """Prepare passage strings for embedding according to the prompt format *fmt*.

    "e5" and "bge" strip each text and prepend their passage prefix. Any
    other *fmt* returns strings unchanged and converts non-string items
    with ``str()``.
    """
    # NOTE(review): confirm the "bge" instruction against the target model card.
    # Changing it invalidates indexes built with the current prompt.
    if fmt == "e5":
        prefix = "passage: "
    elif fmt == "bge":
        prefix = "Represent this document for retrieval: "
    else:
        return [str(t) if not isinstance(t, str) else t for t in texts]
    return [prefix + t.strip() for t in texts]
def format_queries(texts, fmt: str):
    """Prepare query strings for embedding according to the prompt format *fmt*.

    Every query is stripped of surrounding whitespace. "e5" and "bge" also
    prepend their query prefix; any other *fmt* adds nothing.
    """
    if fmt == "e5":
        head = "query: "
    elif fmt == "bge":
        head = "Represent this query for retrieval: "
    else:
        head = ""
    return [head + t.strip() for t in texts]
def build_index(embs: np.ndarray) -> faiss.IndexFlatIP:
    """Build an exact inner-product FAISS index over L2-normalised embeddings.

    On unit-length vectors, inner product equals cosine similarity. The input
    is copied into a fresh C-contiguous float32 array before normalisation,
    so the caller's array is never modified in place.

    Args:
        embs: 2-D array of shape (n_vectors, dim).

    Returns:
        A ``faiss.IndexFlatIP`` containing the normalised vectors.

    Raises:
        ValueError: If *embs* is not 2-D.
    """
    # np.array always copies here; order="C" guarantees the contiguous
    # float32 buffer that faiss.normalize_L2 / Index.add operate on.
    X = np.array(embs, dtype="float32", order="C")
    if X.ndim != 2:
        raise ValueError(f"build_index expects a 2-D (n, dim) array, got shape {X.shape}")
    faiss.normalize_L2(X)
    idx = faiss.IndexFlatIP(X.shape[1])
    idx.add(X)
    return idx
# ---------- Main ----------
def _load_or_build_index(args, enc, chunks, fmt):
    """Return a FAISS index over *chunks*, loading a prebuilt one when available.

    If ``args.faiss_index`` exists, it is loaded after checking that it holds
    one vector per chunk and that its dimension matches the encoder.
    Otherwise the formatted passages are encoded, indexed with
    ``build_index`` and saved as ``recording_text.index`` in ``args.out_dir``.
    """
    if args.faiss_index and os.path.exists(args.faiss_index):
        index = faiss.read_index(args.faiss_index)
        if index.ntotal != len(chunks):
            raise ValueError(
                f"Index size ≠ chunks length ({index.ntotal} vs {len(chunks)})."
            )
        dim = enc.get_sentence_embedding_dimension()
        if dim is not None and index.d != dim:
            raise ValueError(
                f"Index dimension {index.d} does not match model {args.model_id!r} "
                f"(dimension {dim}); rebuild the index."
            )
        print(f"[INFO] Using prebuilt index: {args.faiss_index}")
        return index

    passages = format_passages([c["text"] for c in chunks], fmt)
    doc_embs = enc.encode(
        passages, batch_size=args.batch, convert_to_numpy=True, normalize_embeddings=True
    ).astype("float32")
    index = build_index(doc_embs)
    faiss.write_index(index, os.path.join(args.out_dir, "recording_text.index"))
    return index


def _retrieve(index, q_embs, qids, chunks, topk):
    """Search *index* and return qid -> ranked list of (file, chunk_index) keys, best first."""
    _scores, ids = index.search(q_embs, topk)
    run = {}
    for qid, row in zip(qids, ids.tolist()):
        # FAISS pads unfilled result slots with -1 (index smaller than topk).
        # Indexing chunks[-1] would silently inject the last chunk as a hit.
        run[qid] = [
            (chunks[r]["file"], int(chunks[r]["chunk_index"])) for r in row if r >= 0
        ]
    return run


def _plot_curves(metrics, k_vals, out_dir):
    """Save R@k and nDCG@k line plots as PNG files in *out_dir*."""
    import matplotlib.pyplot as plt

    for metric_prefix, fname in (
        ("R@", "recall_at_k_recording.png"),
        ("nDCG@", "ndcg_at_k_recording.png"),
    ):
        ys = [metrics[f"{metric_prefix}{k}"] for k in k_vals]
        plt.figure(figsize=(6, 4))
        plt.plot(list(k_vals), ys, marker="o")
        plt.xlabel("k")
        plt.ylabel(metric_prefix.rstrip("@"))
        plt.title(f"{metric_prefix} vs k (RecordingRAG)")
        plt.tight_layout()
        plt.savefig(os.path.join(out_dir, fname), dpi=140)
        plt.close()


def _write_csv(metrics, out_dir):
    """Write metrics as a two-column (metric, value) CSV in *out_dir*; return its path."""
    csv_path = os.path.join(out_dir, "retriever_metrics_recording.csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["metric", "value"])
        for k, v in metrics.items():
            w.writerow([k, round(v, 6)])
    return csv_path


def main():
    """Evaluate dense retrieval over lecture chunks; print metrics and save plots and CSV."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--chunks", default='./lecture_chunks.json')
    ap.add_argument("--qrels", default='./qrels_recording.jsonl')
    ap.add_argument("--out_dir", default='./')
    ap.add_argument("--faiss_index", default='./out/embeddings_v0.2_enhanced.faiss', help="Optional prebuilt index; must match model & format")
    ap.add_argument("--model_id", default="sentence-transformers/all-mpnet-base-v2")
    ap.add_argument("--format", default="none", choices=["auto", "none", "e5", "bge"])
    ap.add_argument("--k_list", nargs="+", type=int, default=[1, 3, 5, 10, 20])
    ap.add_argument("--topk", type=int, default=25)
    ap.add_argument("--batch", type=int, default=128)
    args = ap.parse_args()

    ensure_dir(args.out_dir)
    chunks = load_chunks(args.chunks)
    queries, qrels = load_qrels(args.qrels)
    qids = list(qrels.keys())
    if not chunks:
        raise SystemExit(f"[ERROR] No chunks found in {args.chunks}")
    if not qids:
        raise SystemExit(f"[ERROR] No queries found in {args.qrels}")

    fmt = args.format if args.format != "auto" else detect_family(args.model_id)
    print(f"[INFO] model={args.model_id} | format={fmt}")

    k_vals = tuple(sorted(set(args.k_list)))
    # Retrieve at least as deep as the largest cutoff; otherwise P@k/R@k/nDCG@k
    # for k > topk would silently be computed on a truncated ranking.
    topk = max(args.topk, max(k_vals))
    if topk != args.topk:
        print(f"[INFO] topk raised from {args.topk} to {topk} to cover k={max(k_vals)}")

    from sentence_transformers import SentenceTransformer
    enc = SentenceTransformer(args.model_id)
    index = _load_or_build_index(args, enc, chunks, fmt)

    # Queries must use the same resolved prompt format as the passages
    # (`fmt`, not the raw --format value, so "auto" is honoured).
    q_texts = format_queries([queries.get(qid) or "" for qid in qids], fmt)
    q_embs = enc.encode(
        q_texts, batch_size=args.batch, convert_to_numpy=True, normalize_embeddings=True
    ).astype("float32")

    run = _retrieve(index, q_embs, qids, chunks, topk)
    metrics = evaluate(run, qrels, k_vals)
    print("[RESULT]", {k: round(v, 4) for k, v in metrics.items()})

    _plot_curves(metrics, k_vals, args.out_dir)
    csv_path = _write_csv(metrics, args.out_dir)
    print(f"[OK] Saved: {csv_path}")
# Script entry point: run the evaluation only when executed directly, not when imported.
if __name__ == "__main__":
    main()