270 lines
11 KiB
Python
270 lines
11 KiB
Python
import os, json, argparse, math, numpy as np, faiss, sys
|
|
|
|
# ---------- IO ----------
|
|
def load_chunks(p):
    """Load and validate the chunk list stored as a JSON array at ``p``.

    Every element must be an object with at least the keys ``file``,
    ``chunk_index`` and ``text``. Each chunk is normalized in place:
    ``chunk_index`` is coerced to ``int`` and a ``null`` ``text`` becomes
    ``""`` so later string handling (prefixing, ``.strip()``) cannot fail.

    Args:
        p: Path to the chunks JSON file (read as UTF-8).

    Returns:
        The list of (normalized) chunk dicts.

    Raises:
        ValueError: If an element is not an object or lacks a required key;
            the message names the offending position.
    """
    with open(p, "r", encoding="utf-8") as f:
        chunks = json.load(f)
    for i, c in enumerate(chunks):
        if not isinstance(c, dict) or "text" not in c or "file" not in c or "chunk_index" not in c:
            raise ValueError(
                f"Each chunk must have 'file', 'chunk_index', 'text' (offending chunk #{i})"
            )
        # The key check above already guarantees "text" is present, so a
        # setdefault would be a no-op; an explicit null is what can still
        # break downstream string operations.
        if c["text"] is None:
            c["text"] = ""
        c["chunk_index"] = int(c["chunk_index"])
    return chunks
|
|
|
|
def load_qrels(p):
    """Load queries and graded relevance judgments from a JSONL file.

    Each non-blank line is a JSON object with ``qid``, an optional ``query``
    (default ``""``) and ``relevant``: a list of objects carrying ``file``,
    ``chunk_index`` and an optional integer grade ``rel`` (default 1).
    Blank or whitespace-only lines (e.g. extra newlines at the end of the
    file) are skipped instead of failing in ``json.loads``. A repeated
    ``qid`` overwrites the earlier entry.

    Args:
        p: Path to the qrels JSONL file (read as UTF-8).

    Returns:
        ``(queries, qrels)``: ``queries`` maps qid -> query text, and
        ``qrels`` maps qid -> {(file, chunk_index): grade}.
    """
    qrels, queries = {}, {}
    with open(p, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            r = json.loads(line)
            qid = r["qid"]
            queries[qid] = r.get("query", "")
            qrels[qid] = {
                (x["file"], int(x["chunk_index"])): int(x.get("rel", 1))
                for x in r["relevant"]
            }
    return queries, qrels
|
|
|
|
def ensure_dir(p):
    """Create directory ``p`` (including parents) if it does not exist yet.

    An empty path denotes the current working directory, which always
    exists; ``os.makedirs("")`` would raise ``FileNotFoundError``, so that
    case is a no-op.
    """
    if p:
        os.makedirs(p, exist_ok=True)
|
|
|
|
# ---------- Metrics ----------
|
|
def precision_at_k(ranked, relset, k):
    """Return precision at cutoff ``k``.

    This is the number of items in ``ranked[:k]`` that belong to ``relset``,
    divided by ``k``. The denominator is always ``k`` (the standard cutoff
    definition), even when fewer than ``k`` items were retrieved. A short
    ranking is therefore not rewarded, and runs of different lengths stay
    comparable. Returns 0.0 for ``k <= 0``.
    """
    if k <= 0:
        return 0.0
    hits = sum(1 for d in ranked[:k] if d in relset)
    return hits / k
|
|
|
|
def recall_at_k(ranked, relset, k):
    """Fraction of the relevant set found among the first ``k`` ranked items.

    The denominator is ``len(relset)``, floored at 1, so an empty relevant
    set yields 0.0 rather than a division error.
    """
    found = len([doc for doc in ranked[:k] if doc in relset])
    denom = len(relset) or 1
    return found / denom
|
|
|
|
def average_precision(ranked, relset):
    """Average precision of ``ranked`` with respect to the relevant set ``relset``.

    Precision is taken at every rank that holds a relevant item. The sum is
    divided by the total number of relevant items, so relevant items never
    retrieved count as zero. Returns 0.0 when ``relset`` is empty.
    """
    if not relset:
        return 0.0
    total = 0.0
    found = 0
    for rank, doc in enumerate(ranked, start=1):
        if doc not in relset:
            continue
        found += 1
        total += found / rank
    return total / len(relset)
|
|
|
|
def reciprocal_rank(ranked, relset):
    """Return 1/rank (1-based) of the first relevant item in ``ranked``, or 0.0 if none."""
    first = next((pos for pos, doc in enumerate(ranked, 1) if doc in relset), None)
    return 0.0 if first is None else 1.0 / first
|
|
|
|
def zscore(x):
    """Standardize ``x`` to zero mean and (near) unit variance, returned as float32.

    A small epsilon (1e-9) is added to the standard deviation so that a
    constant input maps to all zeros instead of dividing by zero. An empty
    input is returned as an empty float32 array directly. This avoids
    NumPy's "Mean of empty slice" RuntimeWarning and NaN statistics.

    NOTE(review): a second ``zscore`` definition later in this module
    shadows this one at import time.
    """
    x = np.asarray(x, dtype=np.float32)
    if x.size == 0:
        return x
    mu, sd = float(x.mean()), float(x.std()) + 1e-9
    return (x - mu) / sd
|
|
|
|
def dcg_at_k(ranked, gains, k):
    """Discounted cumulative gain of the first ``k`` items of ``ranked``.

    Each item's grade is looked up in ``gains`` (default 0). It contributes
    ``(2**grade - 1) / log2(rank + 1)`` with a 1-based rank. Items whose
    grade is not positive contribute nothing.
    """
    total = 0.0
    for rank, doc in enumerate(ranked[:k], start=1):
        grade = gains.get(doc, 0)
        if grade <= 0:
            continue
        total += (2 ** grade - 1) / math.log2(rank + 1)
    return total
|
|
|
|
def ndcg_at_k(ranked, gains, k):
    """Normalized DCG at cutoff ``k`` for ``ranked`` under graded ``gains``.

    The ideal ranking contains only the positively-graded items, sorted by
    descending grade. It is scored with the same ``dcg_at_k`` as the actual
    ranking. Zero and negative grades must stay out of the ideal list:
    ``dcg_at_k`` already ignores them for the actual ranking, and a negative
    grade would add a negative term ``2**g - 1 < 0`` to IDCG. That term
    would shrink the normalizer and let nDCG exceed 1.

    Returns:
        ``DCG@k / IDCG@k``, or 0.0 when no item has a positive grade.
    """
    ideal = sorted((d for d, g in gains.items() if g > 0), key=lambda d: gains[d], reverse=True)
    idcg = dcg_at_k(ideal, gains, k)
    if idcg <= 0:
        return 0.0
    return dcg_at_k(ranked, gains, k) / idcg
|
|
|
|
def evaluate(run, qrels, k_vals):
    """Average ranking metrics over every query present in ``qrels``.

    Args:
        run: qid -> ranked list of document keys. A qid absent from ``run``
            is scored as an empty ranking.
        qrels: qid -> {document key: grade}. Keys with grade > 0 form the
            binary relevant set used by P/R/MAP/MRR; nDCG uses the grades.
        k_vals: Cutoffs for P@k, R@k and nDCG@k.

    Returns:
        A dict whose keys are all P@k, then all R@k, then all nDCG@k (each
        in ``k_vals`` order), followed by MAP and MRR. Each value is the mean
        over queries; every value is 0.0 when ``qrels`` is empty.
    """
    names = ([f"P@{k}" for k in k_vals]
             + [f"R@{k}" for k in k_vals]
             + [f"nDCG@{k}" for k in k_vals]
             + ["MAP", "MRR"])
    totals = dict.fromkeys(names, 0.0)
    n_queries = 0
    for qid in qrels:
        grades = qrels[qid]
        ranking = run.get(qid, [])
        relevant = {doc for doc, g in grades.items() if g > 0}
        for k in k_vals:
            totals[f"P@{k}"] += precision_at_k(ranking, relevant, k)
            totals[f"R@{k}"] += recall_at_k(ranking, relevant, k)
            totals[f"nDCG@{k}"] += ndcg_at_k(ranking, grades, k)
        totals["MAP"] += average_precision(ranking, relevant)
        totals["MRR"] += reciprocal_rank(ranking, relevant)
        n_queries += 1
    denom = max(1, n_queries)
    return {name: value / denom for name, value in totals.items()}
|
|
|
|
# ---------- Formatting ----------
|
|
def detect_family(model_id: str) -> str:
    """Guess the input-prefix convention of an embedding model from its id.

    Matching is case-insensitive and substring-based:
      - "bge" if the id contains "bge";
      - otherwise "e5" if it contains "e5" but not "multilingual";
      - otherwise "none".

    NOTE(review): the multilingual-E5 model cards also ask for
    "query: "/"passage: " prefixes; confirm the multilingual exclusion is
    intended.
    """
    name = model_id.lower()
    looks_like_e5 = "e5" in name and "multilingual" not in name
    if "bge" in name:
        return "bge"
    return "e5" if looks_like_e5 else "none"
|
|
|
|
def format_passages(texts, fmt: str):
    """Apply the passage-side prefix for retriever family ``fmt``.

    - "e5": ``"passage: <text>"``, with the text stripped.
    - "bge": ``"Represent this document for retrieval: <text>"``, with the
      text stripped.
    - Any other value: the text is left unchanged.

    In every mode, ``None`` becomes ``""`` and other non-string values are
    converted with ``str``. A null chunk therefore neither crashes the
    prefixed modes (``None.strip()``) nor gets embedded as the literal "None".

    Returns:
        A new list of strings, one per input text.
    """
    def as_text(t):
        if t is None:
            return ""
        return t if isinstance(t, str) else str(t)

    cleaned = [as_text(t) for t in texts]
    if fmt == "e5":
        return [f"passage: {t.strip()}" for t in cleaned]
    if fmt == "bge":
        return [f"Represent this document for retrieval: {t.strip()}" for t in cleaned]
    return cleaned
|
|
|
|
def format_queries(texts, fmt: str):
    """Apply the query-side prefix for retriever family ``fmt``.

    Every query is stripped first.
    - "e5": ``"query: <text>"``.
    - "bge": ``"Represent this query for retrieval: <text>"``.
    - Any other value: the stripped text alone.

    ``None`` becomes ``""`` and other non-string values are converted with
    ``str``, so a missing or null query does not crash ``.strip()``.

    Returns:
        A new list of strings, one per input query.
    """
    def as_text(t):
        if t is None:
            return ""
        return t if isinstance(t, str) else str(t)

    stripped = [as_text(t).strip() for t in texts]
    if fmt == "e5":
        return [f"query: {t}" for t in stripped]
    if fmt == "bge":
        return [f"Represent this query for retrieval: {t}" for t in stripped]
    return stripped
|
|
|
|
def zscore(x):
    """Standardize ``x`` to zero mean and (near) unit variance, returned as float32.

    A small epsilon (1e-9) is added to the standard deviation so that a
    constant input maps to all zeros instead of dividing by zero. An empty
    input is returned as an empty float32 array directly. This avoids
    NumPy's "Mean of empty slice" RuntimeWarning and NaN statistics.

    NOTE(review): ``zscore`` is defined twice in this module. This later
    definition is the one bound at import time; consider deleting one of
    the two.
    """
    x = np.asarray(x, dtype=np.float32)
    if x.size == 0:
        return x
    mu, sd = float(x.mean()), float(x.std()) + 1e-9
    return (x - mu) / sd
|
|
|
|
# ---------- Main ----------
|
|
def _parse_args():
    """Define the command-line interface and return the parsed arguments."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--chunks", default='./assets/lecture_chunks.json')
    ap.add_argument("--qrels", default='./assets/qrels_recording.jsonl')
    ap.add_argument("--out_dir", default='./')
    ap.add_argument("--faiss_index", default='./assets/embeddings_v0.2_enhanced.faiss',
                    help="Optional prebuilt index matching model & format")
    ap.add_argument("--model_id", default="sentence-transformers/all-mpnet-base-v2")
    ap.add_argument("--format", default="none", choices=["auto", "none", "e5", "bge"])
    ap.add_argument("--topk", type=int, default=5,
                    help="cutoff k of the nDCG@k objective used by --tune_gamma "
                         "(also a lower bound on first-stage depth)")
    ap.add_argument("--rerank_topM", type=int, default=100, help="first-stage depth to re-rank")
    ap.add_argument("--batch", type=int, default=128)
    # reranker
    ap.add_argument("--reranker_model", default="cross-encoder/ms-marco-MiniLM-L-6-v2")
    ap.add_argument("--fuse", choices=["none", "wsum"], default="none",
                    help="'none' = pure CE order; 'wsum' = z-scored weighted with base sims")
    ap.add_argument("--gamma", type=float, default=0.7, help="weight for CE in 'wsum'")
    ap.add_argument("--tune_gamma", action="store_true",
                    help="Grid-search gamma to maximize nDCG@topk on this run (use on DEV split)")
    ap.add_argument("--gamma_grid", default="0.5,0.6,0.7,0.8",
                    help="Comma-separated gamma values to try if --tune_gamma (e.g., '0.4,0.5,0.6,0.7,0.8')")
    return ap.parse_args()


def _rerank_order(ce_scores, base_sims, fuse, gamma):
    """Return candidate positions sorted best-first.

    With ``fuse == "none"``, candidates are ordered by cross-encoder score
    alone. Otherwise ("wsum") they are ordered by
    ``gamma * zscore(ce_scores) + (1 - gamma) * zscore(base_sims)``.
    """
    if fuse == "none":
        return np.argsort(-ce_scores)
    fused = gamma * zscore(ce_scores) + (1.0 - gamma) * zscore(base_sims)
    return np.argsort(-fused)


def _tune_gamma(gammas, qids, qrels, ce_scores, base_sims, base_keys, k, default_gamma):
    """Grid-search the CE weight of 'wsum' fusion by mean nDCG@k over ``qids``.

    Returns ``(best_gamma, best_mean_ndcg)``. On ties the earliest gamma in
    ``gammas`` wins. With an empty grid, ``(default_gamma, -1.0)`` is
    returned.
    """
    best_g, best_ndcg = default_gamma, -1.0
    for g in gammas:
        ndcgs = []
        for qid, ce_q, sims_q, keys in zip(qids, ce_scores, base_sims, base_keys):
            order = _rerank_order(ce_q, sims_q, "wsum", g)[:k]
            ndcgs.append(ndcg_at_k([keys[i] for i in order], qrels[qid], k))
        m = float(np.mean(ndcgs)) if ndcgs else 0.0
        if m > best_ndcg:
            best_ndcg, best_g = m, g
    return best_g, best_ndcg


def main():
    """Run dense retrieval, cross-encoder re-ranking and evaluation end to end.

    Steps:
      1. Load chunks and qrels.
      2. Load a prebuilt FAISS index, or encode the chunks into an
         inner-product index saved as ``<out_dir>/recording_text.index``.
      3. Retrieve the top-M candidates per query.
      4. Re-score the candidates with a cross-encoder, optionally fused with
         the first-stage similarity; with --tune_gamma the fusion weight is
         grid-searched.
      5. Print metrics for both runs and write
         ``<out_dir>/rerank_metrics_recording.csv``.
    """
    args = _parse_args()

    ensure_dir(args.out_dir)
    chunks = load_chunks(args.chunks)
    if not chunks:
        raise ValueError(f"No chunks found in {args.chunks}")
    queries, qrels = load_qrels(args.qrels)
    qids = list(qrels.keys())
    if not qids:
        raise ValueError(f"No queries found in {args.qrels}")

    fmt = args.format if args.format != "auto" else detect_family(args.model_id)
    print(f"[INFO] retriever={args.model_id} | format={fmt} | reranker={args.reranker_model}")

    # 1) Retriever: build/load index
    from sentence_transformers import SentenceTransformer, CrossEncoder
    enc = SentenceTransformer(args.model_id)

    if args.faiss_index and os.path.exists(args.faiss_index):
        index = faiss.read_index(args.faiss_index)
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if index.ntotal != len(chunks):
            raise ValueError("Index size ≠ chunks length.")
    else:
        passages = format_passages([c["text"] for c in chunks], fmt)
        doc_embs = enc.encode(passages, batch_size=args.batch, convert_to_numpy=True,
                              normalize_embeddings=True).astype("float32")
        faiss.normalize_L2(doc_embs)
        index = faiss.IndexFlatIP(doc_embs.shape[1])
        index.add(doc_embs)
        faiss.write_index(index, os.path.join(args.out_dir, "recording_text.index"))

    # The cross-encoder gets the raw query text. The retriever gets the
    # prefix for the *resolved* format, so "auto" and "bge" are honored and
    # not only an explicit "--format e5".
    raw_queries = [str(queries.get(qid) or "").strip() for qid in qids]
    q_embs = enc.encode(format_queries(raw_queries, fmt), batch_size=args.batch,
                        convert_to_numpy=True, normalize_embeddings=True).astype("float32")

    # First-stage search. Never request more neighbours than the index holds:
    # FAISS pads missing results with id -1, and chunks[-1] would then
    # silently alias the last chunk.
    depth = max(args.rerank_topM, args.topk)
    M = max(1, min(depth, index.ntotal))
    sims, ids = index.search(q_embs, M)

    # 2) Cross-encoder re-rank. Per-query caches are aligned with `qids` so
    # that gamma can be tuned afterwards without re-scoring.
    reranker = CrossEncoder(args.reranker_model, max_length=512)
    run_base = {}
    base_keys, base_sims, ce_scores = [], [], []
    for qi, qid in enumerate(qids):
        # Non-exhaustive indexes (e.g. IVF) may still return -1 padding.
        keep = ids[qi] >= 0
        rows = ids[qi][keep].tolist()
        keys = [(chunks[r]["file"], int(chunks[r]["chunk_index"])) for r in rows]
        run_base[qid] = keys

        cand_texts = [(chunks[r]["text"] or "").replace("\n", " ").strip() for r in rows]
        if rows:
            scores = reranker.predict([(raw_queries[qi], t) for t in cand_texts], batch_size=32)
        else:
            scores = []
        base_keys.append(keys)
        base_sims.append(sims[qi][keep].astype(np.float32))
        ce_scores.append(np.asarray(scores, dtype=np.float32))

    # Only meaningful if we're fusing scores
    if args.tune_gamma:
        if args.fuse == "none":
            print("[WARN] --tune_gamma ignored because --fuse=none")
        else:
            gammas = [float(x) for x in args.gamma_grid.split(",") if x.strip()]
            best_g, best_ndcg = _tune_gamma(gammas, qids, qrels, ce_scores, base_sims,
                                            base_keys, args.topk, args.gamma)
            print(f"[TUNE] gamma={best_g:.2f} (nDCG@{args.topk}={best_ndcg:.4f})")
            args.gamma = best_g

    # Final re-ranked run, built in every mode. With --tune_gamma and
    # --fuse none, it was previously left empty. The whole candidate list is
    # kept (same depth as run_base) rather than cut to --topk. The
    # @10/MAP/MRR numbers below would otherwise compare a topk-length list
    # against the full first-stage list.
    run_rer = {
        qid: [keys[i] for i in _rerank_order(ce_q, sims_q, args.fuse, args.gamma)]
        for qid, keys, ce_q, sims_q in zip(qids, base_keys, ce_scores, base_sims)
    }

    # 3) Evaluate base vs re-ranked
    k_vals = (1, 3, 5, 10)
    base_metrics = evaluate(run_base, qrels, k_vals)
    rer_metrics = evaluate(run_rer, qrels, k_vals)

    print("\n[BASE] ", {k: round(v, 4) for k, v in base_metrics.items()})
    print("[RERANK]", {k: round(v, 4) for k, v in rer_metrics.items()})

    # Save CSV
    import csv
    csv_path = os.path.join(args.out_dir, "rerank_metrics_recording.csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        cols = ["stage"] + list(base_metrics.keys())
        w = csv.DictWriter(f, fieldnames=cols)
        w.writeheader()
        w.writerow({"stage": "base", **{k: round(v, 6) for k, v in base_metrics.items()}})
        w.writerow({"stage": "rerank", **{k: round(v, 6) for k, v in rer_metrics.items()}})
    print(f"[OK] Saved: {csv_path}")
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing the module does not call main().
    main()
|