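"""Build a labeling pool for recording/lecture retrieval.

Reads chunks.json (fields: file, chunk_index, text) and queries.jsonl (one JSON
object per line with qid and query), embeds both with a sentence-transformers
model, retrieves the top-M candidate chunks per query from a FAISS inner-product
index (loaded from --faiss_index or built here), and writes
label_pool_recording.csv with an empty 'rel' column for manual relevance labels.
"""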
import argparse
import csv
import json
import os
import random
import sys

import faiss
import numpy as np

def load_chunks(path):
    """Load chunks.json and validate that each entry has the required fields."""
    with open(path, "r", encoding="utf-8") as f:
        chunks = json.load(f)
    # minimal hygiene
    for c in chunks:
        c.setdefault("text", "")
        if "chunk_index" not in c:
            raise ValueError("chunks.json must include 'chunk_index' per entry")
        if "file" not in c:
            raise ValueError("chunks.json must include 'file' per entry")
    return chunks


def load_queries(path):
    """Load queries.jsonl; each line is a JSON object with 'qid' and 'query'."""
    qs = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            r = json.loads(line)
            qs.append((r["qid"], r["query"]))
    return qs


def build_index(embs: np.ndarray) -> faiss.IndexFlatIP:
    """Build a flat inner-product index; with L2-normalized vectors this is cosine similarity."""
    embs = embs.astype("float32")
    faiss.normalize_L2(embs)
    idx = faiss.IndexFlatIP(embs.shape[1])
    idx.add(embs)
    return idx


def maybe_load_index(index_path: str):
    """Return an existing FAISS index if the path is set and exists, else None."""
    return faiss.read_index(index_path) if (index_path and os.path.exists(index_path)) else None


def ensure_dir(p):
    os.makedirs(p, exist_ok=True)


def detect_family(model_id: str) -> str:
    """Infer the prompt-formatting family from the model id."""
    mid = model_id.lower()
    if "e5" in mid and "multilingual" not in mid:
        return "e5"  # e5-base-v2 (EN)
    if "bge" in mid:
        return "bge"  # BAAI/bge-*-en*
    # MPNet, MiniLM, GTE, etc. -> no special formatting
    return "none"


def format_passages(texts, fmt: str):
    if fmt == "e5":
        return [f"passage: {t.strip()}" for t in texts]
    if fmt == "bge":
        return [f"Represent this document for retrieval: {t.strip()}" for t in texts]
    return [t if isinstance(t, str) else str(t) for t in texts]


def format_queries(texts, fmt: str):
    if fmt == "e5":
        return [f"query: {t.strip()}" for t in texts]
    if fmt == "bge":
        return [f"Represent this query for retrieval: {t.strip()}" for t in texts]
    return [t.strip() for t in texts]

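# Illustrative examples of the prompt formats produced above (query text is made up):
#   format_queries(["what is gradient descent"], "e5")
#       -> ["query: what is gradient descent"]
#   format_queries(["what is gradient descent"], "bge")
#       -> ["Represent this query for retrieval: what is gradient descent"]
#   format_passages(["Gradient descent updates parameters ..."], "e5")
#       -> ["passage: Gradient descent updates parameters ..."]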

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--chunks", default="./lecture_chunks.json",
                    help="chunks.json with fields: file, chunk_index, text")
    ap.add_argument("--queries", default="./queries.jsonl",
                    help="queries.jsonl with {qid, query}")
    ap.add_argument("--out_dir", default="./",
                    help="Where to write label_pool_recording.csv and the index if one is built")
    ap.add_argument("--faiss_index", default="./out/embeddings_v0.2_enhanced.faiss",
                    help="Optional existing FAISS index for chunks; if missing, build one")
    ap.add_argument("--model_id", default="sentence-transformers/all-mpnet-base-v2",
                    help="e.g., sentence-transformers/all-mpnet-base-v2, BAAI/bge-base-en-v1.5, "
                         "Alibaba-NLP/gte-base-en-v1.5, intfloat/e5-base-v2")
    ap.add_argument("--format", default="none", choices=["auto", "none", "e5", "bge"],
                    help="Query/passage prompt formatting. 'auto' infers from model_id.")
    ap.add_argument("--topM", type=int, default=20,
                    help="How many candidate chunks are retrieved per query")
    ap.add_argument("--batch", type=int, default=128)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    ensure_dir(args.out_dir)

    chunks = load_chunks(args.chunks)
    queries = load_queries(args.queries)

    # choose formatting
    fmt = args.format
    if fmt == "auto":
        fmt = detect_family(args.model_id)
    print(f"[INFO] model_id={args.model_id} | format={fmt}")

    # encoder
    try:
        from sentence_transformers import SentenceTransformer
        enc = SentenceTransformer(args.model_id)
    except Exception as e:
        print(f"[ERROR] Failed to load model {args.model_id}: {e}", file=sys.stderr)
        sys.exit(1)

    # index: use the provided one or build it from the chunks
    index = maybe_load_index(args.faiss_index)
    if index is None:
        passages = format_passages([c["text"] for c in chunks], fmt)
        doc_embs = enc.encode(passages, batch_size=args.batch, convert_to_numpy=True,
                              normalize_embeddings=True).astype("float32")
        index = build_index(doc_embs)
        faiss_path = os.path.join(args.out_dir, "recording_text.index")
        faiss.write_index(index, faiss_path)
        # save a docmap so index rows can be mapped back to (file, chunk_index)
        with open(os.path.join(args.out_dir, "recording_docmap.json"), "w", encoding="utf-8") as f:
            json.dump([{"row": i, "file": c["file"], "chunk_index": int(c["chunk_index"])}
                       for i, c in enumerate(chunks)],
                      f, ensure_ascii=False, indent=2)
        print(f"[OK] Built index → {faiss_path}")
    else:
        assert index.ntotal == len(chunks), \
            f"FAISS index has {index.ntotal} vectors but chunks.json has {len(chunks)} entries."

    # queries → embeddings
    q_texts = format_queries([q for _, q in queries], fmt)
    q_embs = enc.encode(q_texts, batch_size=args.batch, convert_to_numpy=True,
                        normalize_embeddings=True).astype("float32")

    # Retrieve the topM most similar chunks for each query embedding.
    # Both arrays have shape (num_queries, topM): similarity scores in
    # descending order and the corresponding row indices into chunks.json.
    sims, ids = index.search(q_embs, args.topM)

    # write labeling CSV (the empty 'rel' column is filled in by the annotator)
    out_csv = os.path.join(args.out_dir, "label_pool_recording.csv")
    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["qid", "query", "file", "chunk_index", "candidate_rank", "base_score", "preview", "rel"])
        for (qid, q), row_ids, row_sims in zip(queries, ids, sims):
            seen = set()
            rank = 0
            for sid, sc in zip(row_ids.tolist(), row_sims.tolist()):
                c = chunks[sid]
                # deduplicate hits that map to the same (file, chunk_index)
                key = (c["file"], int(c["chunk_index"]))
                if key in seen:
                    continue
                seen.add(key)
                rank += 1
                preview = (c["text"] or "").replace("\n", " ")[:300]
                w.writerow([qid, q, c["file"], int(c["chunk_index"]), rank, round(float(sc), 4), preview, ""])
    print(f"[OK] Wrote labeling sheet: {out_csv}")


if __name__ == "__main__":
    main()
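# Example invocation (the script name is a placeholder; the flags mirror the
# argparse defaults above, so adjust paths and model to your setup):
#   python label_pool_recording.py --chunks ./lecture_chunks.json --queries ./queries.jsonl \
#       --out_dir ./out --model_id BAAI/bge-base-en-v1.5 --format auto --topM 20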