diff --git a/benchmarks/run_benchmark.py b/benchmarks/run_benchmark.py index 1342a11..fe6d030 100644 --- a/benchmarks/run_benchmark.py +++ b/benchmarks/run_benchmark.py @@ -1,5 +1,17 @@ #!/usr/bin/env python3 -"""Simple BM25 vs Bayesian BM25 benchmark runner.""" +"""BM25 vs Bayesian BM25 benchmark runner. + +Evaluates ranking quality and probability calibration across: + 1. Raw BM25 (baseline) + 2. Bayesian BM25 (fixed params) + 3. Bayesian BM25 (batch-fitted params via BayesianProbabilityTransform) + 4. Hybrid OR / AND (when embeddings available) + 5. Balanced log-odds fusion + 6. Gated log-odds fusion (relu, swish) + 7. Learned-weight log-odds fusion (LearnableLogOddsWeights) + 8. Attention-weighted log-odds fusion (AttentionLogOddsWeights) + 9. Calibration diagnostics (ECE, Brier, reliability diagram) +""" from __future__ import annotations @@ -7,7 +19,7 @@ import json import math import time -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import Dict, Iterable, List, Optional, Tuple @@ -18,15 +30,15 @@ class DocRecord: doc_id: str text: str - embedding: List[float] + embedding: List[float] = field(default_factory=list) @dataclass class QueryRecord: query_id: str text: str - terms: Optional[List[str]] - embedding: Optional[List[float]] + terms: Optional[List[str]] = None + embedding: Optional[List[float]] = None def load_jsonl(path: Path) -> Iterable[dict]: @@ -83,13 +95,20 @@ def load_qrels(path: Path) -> Dict[str, Dict[str, float]]: return qrels with path.open("r", encoding="utf-8") as handle: + first_line = True for line in handle: line = line.strip() if not line or line.startswith("#"): continue - parts = line.split() + parts = line.split("\t") if "\t" in line else line.split() if len(parts) < 3: continue + if first_line: + first_line = False + try: + float(parts[2]) + except ValueError: + continue qid, did, rel_str = parts[0], parts[1], parts[2] rel = float(rel_str) qrels.setdefault(qid, 
{})[did] = rel @@ -102,6 +121,30 @@ def parse_cutoffs(raw: str) -> List[int]: return unique +def encode_embeddings( + docs: List[DocRecord], + queries: List[QueryRecord], + model_name: str, + batch_size: int, +) -> None: + from sentence_transformers import SentenceTransformer + + print(f"Loading embedding model: {model_name}") + model = SentenceTransformer(model_name) + + doc_texts = [doc.text for doc in docs] + print(f"Encoding {len(doc_texts)} documents...") + doc_embs = model.encode(doc_texts, batch_size=batch_size, show_progress_bar=True) + for i, doc in enumerate(docs): + doc.embedding = doc_embs[i].tolist() + + query_texts = [q.text for q in queries] + print(f"Encoding {len(query_texts)} queries...") + query_embs = model.encode(query_texts, batch_size=batch_size, show_progress_bar=True) + for i, q in enumerate(queries): + q.embedding = query_embs[i].tolist() + + def build_corpus(docs: List[DocRecord]) -> bb.Corpus: corpus = bb.Corpus(None) for doc in docs: @@ -198,6 +241,470 @@ def evaluate( return metrics +def evaluate_hybrid( + queries: List[QueryRecord], + docs: List[bb.Document], + scorer_name: str, + score_fn, + qrels: Dict[str, Dict[str, float]], + tokenizer: bb.Tokenizer, + cutoffs: List[int], +) -> Dict[str, float]: + metrics = {f"map@{k}": 0.0 for k in cutoffs} + metrics.update({f"ndcg@{k}": 0.0 for k in cutoffs}) + metrics.update({f"mrr@{k}": 0.0 for k in cutoffs}) + + counted = 0 + start = time.perf_counter() + for query in queries: + rel_map = qrels.get(query.query_id, {}) + if not rel_map: + continue + if query.embedding is None: + continue + terms = query.terms or tokenizer.tokenize(query.text) + scores = [(doc.id, score_fn(terms, query.embedding, doc)) for doc in docs] + ranked = rank_docs(scores) + for k in cutoffs: + metrics[f"map@{k}"] += average_precision_at_k(ranked, rel_map, k) + metrics[f"ndcg@{k}"] += ndcg_at_k(ranked, rel_map, k) + metrics[f"mrr@{k}"] += mrr_at_k(ranked, rel_map, k) + counted += 1 + + elapsed = time.perf_counter() - 
start + if counted == 0: + return {"scorer": scorer_name, "queries": 0, "elapsed_s": elapsed} + + for key in list(metrics.keys()): + metrics[key] /= counted + metrics["scorer"] = scorer_name + metrics["queries"] = counted + metrics["elapsed_s"] = elapsed + return metrics + + +def evaluate_balanced_fusion( + queries: List[QueryRecord], + docs: List[bb.Document], + bayes: bb.BayesianBM25Scorer, + vector: bb.VectorScorer, + qrels: Dict[str, Dict[str, float]], + tokenizer: bb.Tokenizer, + cutoffs: List[int], + weight: float = 0.5, +) -> Dict[str, float]: + metrics = {f"map@{k}": 0.0 for k in cutoffs} + metrics.update({f"ndcg@{k}": 0.0 for k in cutoffs}) + metrics.update({f"mrr@{k}": 0.0 for k in cutoffs}) + + doc_ids = [doc.id for doc in docs] + counted = 0 + start = time.perf_counter() + for query in queries: + rel_map = qrels.get(query.query_id, {}) + if not rel_map: + continue + if query.embedding is None: + continue + terms = query.terms or tokenizer.tokenize(query.text) + + sparse_probs = [bayes.score(terms, doc) for doc in docs] + dense_sims = [vector.score(query.embedding, doc) for doc in docs] + fused = bb.balanced_log_odds_fusion(sparse_probs, dense_sims, weight) + + scores = list(zip(doc_ids, fused)) + ranked = rank_docs(scores) + for k in cutoffs: + metrics[f"map@{k}"] += average_precision_at_k(ranked, rel_map, k) + metrics[f"ndcg@{k}"] += ndcg_at_k(ranked, rel_map, k) + metrics[f"mrr@{k}"] += mrr_at_k(ranked, rel_map, k) + counted += 1 + + elapsed = time.perf_counter() - start + if counted == 0: + return {"scorer": "balanced_fusion", "queries": 0, "elapsed_s": elapsed} + + for key in list(metrics.keys()): + metrics[key] /= counted + metrics["scorer"] = "balanced_fusion" + metrics["queries"] = counted + metrics["elapsed_s"] = elapsed + return metrics + + +def evaluate_gated_fusion( + queries: List[QueryRecord], + docs: List[bb.Document], + bayes: bb.BayesianBM25Scorer, + vector: bb.VectorScorer, + qrels: Dict[str, Dict[str, float]], + tokenizer: 
bb.Tokenizer, + cutoffs: List[int], + gating: str, +) -> Dict[str, float]: + metrics = {f"map@{k}": 0.0 for k in cutoffs} + metrics.update({f"ndcg@{k}": 0.0 for k in cutoffs}) + metrics.update({f"mrr@{k}": 0.0 for k in cutoffs}) + + counted = 0 + start = time.perf_counter() + for query in queries: + rel_map = qrels.get(query.query_id, {}) + if not rel_map: + continue + if query.embedding is None: + continue + terms = query.terms or tokenizer.tokenize(query.text) + + doc_scores: List[Tuple[str, float]] = [] + for doc in docs: + sparse_prob = bayes.score(terms, doc) + dense_sim = vector.score(query.embedding, doc) + dense_prob = bb.cosine_to_probability(dense_sim) + fused = bb.log_odds_conjunction( + [sparse_prob, dense_prob], gating=gating, + ) + doc_scores.append((doc.id, fused)) + + ranked = rank_docs(doc_scores) + for k in cutoffs: + metrics[f"map@{k}"] += average_precision_at_k(ranked, rel_map, k) + metrics[f"ndcg@{k}"] += ndcg_at_k(ranked, rel_map, k) + metrics[f"mrr@{k}"] += mrr_at_k(ranked, rel_map, k) + counted += 1 + + elapsed = time.perf_counter() - start + scorer_name = f"gated_{gating}" + if counted == 0: + return {"scorer": scorer_name, "queries": 0, "elapsed_s": elapsed} + + for key in list(metrics.keys()): + metrics[key] /= counted + metrics["scorer"] = scorer_name + metrics["queries"] = counted + metrics["elapsed_s"] = elapsed + return metrics + + +def evaluate_learned_weights_fusion( + queries: List[QueryRecord], + docs: List[bb.Document], + bayes: bb.BayesianBM25Scorer, + vector: bb.VectorScorer, + qrels: Dict[str, Dict[str, float]], + tokenizer: bb.Tokenizer, + cutoffs: List[int], +) -> Dict[str, float]: + # Phase 1: collect training data from all queries with qrels + train_probs: List[List[float]] = [] + train_labels: List[float] = [] + for query in queries: + rel_map = qrels.get(query.query_id, {}) + if not rel_map: + continue + if query.embedding is None: + continue + terms = query.terms or tokenizer.tokenize(query.text) + for doc in docs: + 
did = doc.id + if did not in rel_map: + continue + sparse_prob = bayes.score(terms, doc) + dense_sim = vector.score(query.embedding, doc) + dense_prob = bb.cosine_to_probability(dense_sim) + train_probs.append([sparse_prob, dense_prob]) + train_labels.append(1.0 if rel_map[did] > 0 else 0.0) + + if len(train_probs) < 4: + return {"scorer": "learned_weights", "queries": 0, "elapsed_s": 0.0} + + # Phase 2: learn weights + learner = bb.LearnableLogOddsWeights(2) + learner.fit(train_probs, train_labels) + weights = learner.weights + + # Phase 3: evaluate with learned weights + metrics = {f"map@{k}": 0.0 for k in cutoffs} + metrics.update({f"ndcg@{k}": 0.0 for k in cutoffs}) + metrics.update({f"mrr@{k}": 0.0 for k in cutoffs}) + + counted = 0 + start = time.perf_counter() + for query in queries: + rel_map = qrels.get(query.query_id, {}) + if not rel_map: + continue + if query.embedding is None: + continue + terms = query.terms or tokenizer.tokenize(query.text) + + doc_scores: List[Tuple[str, float]] = [] + for doc in docs: + sparse_prob = bayes.score(terms, doc) + dense_sim = vector.score(query.embedding, doc) + dense_prob = bb.cosine_to_probability(dense_sim) + fused = bb.log_odds_conjunction( + [sparse_prob, dense_prob], weights=weights, + ) + doc_scores.append((doc.id, fused)) + + ranked = rank_docs(doc_scores) + for k in cutoffs: + metrics[f"map@{k}"] += average_precision_at_k(ranked, rel_map, k) + metrics[f"ndcg@{k}"] += ndcg_at_k(ranked, rel_map, k) + metrics[f"mrr@{k}"] += mrr_at_k(ranked, rel_map, k) + counted += 1 + + elapsed = time.perf_counter() - start + if counted == 0: + return {"scorer": "learned_weights", "queries": 0, "elapsed_s": elapsed} + + for key in list(metrics.keys()): + metrics[key] /= counted + metrics["scorer"] = "learned_weights" + metrics["queries"] = counted + metrics["elapsed_s"] = elapsed + metrics["weights"] = weights + return metrics + + +def evaluate_attention_fusion( + queries: List[QueryRecord], + docs: List[bb.Document], + bayes: 
bb.BayesianBM25Scorer, + vector: bb.VectorScorer, + qrels: Dict[str, Dict[str, float]], + tokenizer: bb.Tokenizer, + cutoffs: List[int], +) -> Dict[str, float]: + """Evaluate AttentionLogOddsWeights fusion. + + Uses query length as a single query feature so the attention mechanism + can learn query-dependent signal weights. + """ + n_signals = 2 + n_query_features = 1 # query length + + # Phase 1: collect training data + train_probs: List[List[float]] = [] + train_labels: List[float] = [] + train_features: List[float] = [] + train_query_ids: List[int] = [] + qid_map: Dict[str, int] = {} + + for query in queries: + rel_map = qrels.get(query.query_id, {}) + if not rel_map: + continue + if query.embedding is None: + continue + terms = query.terms or tokenizer.tokenize(query.text) + qlen = float(len(terms)) + if query.query_id not in qid_map: + qid_map[query.query_id] = len(qid_map) + qid_int = qid_map[query.query_id] + + for doc in docs: + did = doc.id + if did not in rel_map: + continue + sparse_prob = bayes.score(terms, doc) + dense_sim = vector.score(query.embedding, doc) + dense_prob = bb.cosine_to_probability(dense_sim) + train_probs.append([sparse_prob, dense_prob]) + train_labels.append(1.0 if rel_map[did] > 0 else 0.0) + train_features.append(qlen) + train_query_ids.append(qid_int) + + if len(train_probs) < 4: + return {"scorer": "attention", "queries": 0, "elapsed_s": 0.0} + + # Phase 2: fit attention model + n_samples = len(train_probs) + flat_probs = [p for pair in train_probs for p in pair] + attn = bb.AttentionLogOddsWeights(n_signals, n_query_features) + attn.fit( + flat_probs, train_labels, train_features, + n_samples, + query_ids=train_query_ids, + ) + + # Phase 3: evaluate + metrics = {f"map@{k}": 0.0 for k in cutoffs} + metrics.update({f"ndcg@{k}": 0.0 for k in cutoffs}) + metrics.update({f"mrr@{k}": 0.0 for k in cutoffs}) + + counted = 0 + start = time.perf_counter() + for query in queries: + rel_map = qrels.get(query.query_id, {}) + if not rel_map: 
+ continue + if query.embedding is None: + continue + terms = query.terms or tokenizer.tokenize(query.text) + qlen = float(len(terms)) + + # Score all docs: collect probs matrix for this query + doc_ids: List[str] = [] + query_probs: List[float] = [] + for doc in docs: + sparse_prob = bayes.score(terms, doc) + dense_sim = vector.score(query.embedding, doc) + dense_prob = bb.cosine_to_probability(dense_sim) + doc_ids.append(doc.id) + query_probs.extend([sparse_prob, dense_prob]) + + n_docs_q = len(doc_ids) + # combine(probs, m=n_rows, query_features, m_q=n_feature_rows) + # Pass m_q=1 so the single query feature row is broadcast to all docs + query_features = [qlen] + fused_scores = attn.combine( + query_probs, n_docs_q, + query_features, 1, + use_averaged=True, + ) + + doc_scores = list(zip(doc_ids, fused_scores)) + ranked = rank_docs(doc_scores) + for k in cutoffs: + metrics[f"map@{k}"] += average_precision_at_k(ranked, rel_map, k) + metrics[f"ndcg@{k}"] += ndcg_at_k(ranked, rel_map, k) + metrics[f"mrr@{k}"] += mrr_at_k(ranked, rel_map, k) + counted += 1 + + elapsed = time.perf_counter() - start + if counted == 0: + return {"scorer": "attention", "queries": 0, "elapsed_s": elapsed} + + for key in list(metrics.keys()): + metrics[key] /= counted + metrics["scorer"] = "attention" + metrics["queries"] = counted + metrics["elapsed_s"] = elapsed + return metrics + + +def evaluate_fitted_bayesian( + queries: List[QueryRecord], + docs: List[bb.Document], + bm25: bb.BM25Scorer, + qrels: Dict[str, Dict[str, float]], + tokenizer: bb.Tokenizer, + cutoffs: List[int], + alpha: float, + beta: float, +) -> Dict[str, float]: + # Phase 1: collect (score, label) pairs for fitting + train_scores: List[float] = [] + train_labels: List[float] = [] + for query in queries: + rel_map = qrels.get(query.query_id, {}) + if not rel_map: + continue + terms = query.terms or tokenizer.tokenize(query.text) + for doc in docs: + did = doc.id + if did not in rel_map: + continue + raw_score = 
bm25.score(terms, doc) + if raw_score > 0.0: + train_scores.append(raw_score) + train_labels.append(1.0 if rel_map[did] > 0 else 0.0) + + if len(train_scores) < 4: + return {"scorer": "bayesian_fitted", "queries": 0, "elapsed_s": 0.0} + + # Phase 2: fit BayesianProbabilityTransform + transform = bb.BayesianProbabilityTransform(alpha=alpha, beta=beta) + transform.fit(train_scores, train_labels) + fitted_alpha = transform.alpha + fitted_beta = transform.beta + + # Phase 3: evaluate with fitted params + fitted_bayes = bb.BayesianBM25Scorer(bm25, fitted_alpha, fitted_beta) + + metrics = {f"map@{k}": 0.0 for k in cutoffs} + metrics.update({f"ndcg@{k}": 0.0 for k in cutoffs}) + metrics.update({f"mrr@{k}": 0.0 for k in cutoffs}) + + all_probs: List[float] = [] + all_labels: List[float] = [] + counted = 0 + start = time.perf_counter() + for query in queries: + rel_map = qrels.get(query.query_id, {}) + if not rel_map: + continue + terms = query.terms or tokenizer.tokenize(query.text) + scores = [(doc.id, fitted_bayes.score(terms, doc)) for doc in docs] + ranked = rank_docs(scores) + for k in cutoffs: + metrics[f"map@{k}"] += average_precision_at_k(ranked, rel_map, k) + metrics[f"ndcg@{k}"] += ndcg_at_k(ranked, rel_map, k) + metrics[f"mrr@{k}"] += mrr_at_k(ranked, rel_map, k) + + for did, prob in scores: + if did in rel_map and prob > 0.0: + all_probs.append(prob) + all_labels.append(1.0 if rel_map[did] > 0 else 0.0) + counted += 1 + + elapsed = time.perf_counter() - start + if counted == 0: + return {"scorer": "bayesian_fitted", "queries": 0, "elapsed_s": elapsed} + + for key in list(metrics.keys()): + metrics[key] /= counted + metrics["scorer"] = "bayesian_fitted" + metrics["queries"] = counted + metrics["elapsed_s"] = elapsed + metrics["fitted_alpha"] = fitted_alpha + metrics["fitted_beta"] = fitted_beta + + if all_probs: + metrics["ece"] = bb.expected_calibration_error(all_probs, all_labels) + metrics["brier"] = bb.brier_score(all_probs, all_labels) + + return metrics + 
+ +def collect_calibration_data( + queries: List[QueryRecord], + docs: List[bb.Document], + scorer_name: str, + score_fn, + qrels: Dict[str, Dict[str, float]], + tokenizer: bb.Tokenizer, +) -> Dict[str, float]: + all_probs: List[float] = [] + all_labels: List[float] = [] + for query in queries: + rel_map = qrels.get(query.query_id, {}) + if not rel_map: + continue + terms = query.terms or tokenizer.tokenize(query.text) + for doc in docs: + did = doc.id + if did not in rel_map: + continue + prob = score_fn(terms, doc) + if prob > 0.0: + all_probs.append(prob) + all_labels.append(1.0 if rel_map[did] > 0 else 0.0) + + if not all_probs: + return {"scorer": scorer_name} + + report = bb.calibration_report(all_probs, all_labels) + return { + "scorer": scorer_name, + "n_samples": report.n_samples, + "ece": report.ece, + "brier": report.brier, + } + + def format_table(results: List[Dict[str, float]], cutoffs: List[int]) -> str: headers = ["scorer", "queries", "elapsed_s"] for k in cutoffs: @@ -215,6 +722,21 @@ def format_table(results: List[Dict[str, float]], cutoffs: List[int]) -> str: return "\n".join(lines) +def format_calibration_table(calibration_results: List[Dict[str, float]]) -> str: + headers = ["scorer", "n_samples", "ece", "brier"] + lines = ["\t".join(headers)] + for row in calibration_results: + values = [] + for h in headers: + val = row.get(h, "") + if isinstance(val, float): + values.append(f"{val:.6f}") + else: + values.append(str(val)) + lines.append("\t".join(values)) + return "\n".join(lines) + + def main() -> None: parser = argparse.ArgumentParser(description="BM25 vs Bayesian BM25 benchmark runner") parser.add_argument("--docs", type=Path, required=True, help="JSONL docs with doc_id + text") @@ -227,6 +749,9 @@ def main() -> None: parser.add_argument("--query-text", default="text") parser.add_argument("--query-terms", default=None) parser.add_argument("--query-embedding", default=None) + parser.add_argument("--embedding-model", default=None, + 
help="sentence-transformers model name (e.g. all-MiniLM-L6-v2)") + parser.add_argument("--embedding-batch-size", type=int, default=64) parser.add_argument("--bm25-k1", type=float, default=1.2) parser.add_argument("--bm25-b", type=float, default=0.75) parser.add_argument("--alpha", type=float, default=1.0) @@ -252,6 +777,9 @@ def main() -> None: if args.max_queries: queries = queries[: args.max_queries] + if args.embedding_model: + encode_embeddings(docs, queries, args.embedding_model, args.embedding_batch_size) + qrels = load_qrels(args.qrels) corpus = build_corpus(docs) @@ -269,76 +797,143 @@ def main() -> None: hybrid_or = bb.HybridScorer(bayes, vector) hybrid_and = bb.HybridScorer(bayes, vector) + # ----------------------------------------------------------------------- + # Ranking evaluation + # ----------------------------------------------------------------------- results = [] results.append( evaluate( - queries, - doc_objs, - "bm25", + queries, doc_objs, "bm25", lambda terms, doc: bm25.score(terms, doc), - qrels, - tokenizer, - cutoffs, + qrels, tokenizer, cutoffs, ) ) results.append( evaluate( - queries, - doc_objs, - "bayesian", + queries, doc_objs, "bayesian", lambda terms, doc: bayes.score(terms, doc), - qrels, - tokenizer, - cutoffs, + qrels, tokenizer, cutoffs, + ) + ) + + # Bayesian with batch-fitted parameters + results.append( + evaluate_fitted_bayesian( + queries, doc_objs, bm25, qrels, tokenizer, cutoffs, + args.alpha, args.beta, ) ) if has_embeddings: - query_embedding_map = {} - for q in queries: - if q.embedding: - query_embedding_map[q.query_id] = q.embedding - - def make_hybrid_fn(scorer_method): - def fn(terms, doc): - qid = None - for q in queries: - q_terms = q.terms or tokenizer.tokenize(q.text) - if q_terms == list(terms): - qid = q.query_id - break - emb = query_embedding_map.get(qid, [0.0] * len(doc.embedding)) - return scorer_method(terms, emb, doc) - return fn + results.append( + evaluate_hybrid( + queries, doc_objs, "hybrid_or", + 
hybrid_or.score_or, qrels, tokenizer, cutoffs, + ) + ) + results.append( + evaluate_hybrid( + queries, doc_objs, "hybrid_and", + hybrid_and.score_and, qrels, tokenizer, cutoffs, + ) + ) + results.append( + evaluate_balanced_fusion( + queries, doc_objs, bayes, vector, + qrels, tokenizer, cutoffs, + ) + ) + + # Gated fusion variants + for gating in ("relu", "swish"): + results.append( + evaluate_gated_fusion( + queries, doc_objs, bayes, vector, + qrels, tokenizer, cutoffs, gating, + ) + ) + # Learned-weight fusion results.append( - evaluate( - queries, - doc_objs, - "hybrid_or", - make_hybrid_fn(hybrid_or.score_or), - qrels, - tokenizer, - cutoffs, + evaluate_learned_weights_fusion( + queries, doc_objs, bayes, vector, + qrels, tokenizer, cutoffs, ) ) + + # Attention-weighted fusion results.append( - evaluate( - queries, - doc_objs, - "hybrid_and", - make_hybrid_fn(hybrid_and.score_and), - qrels, - tokenizer, - cutoffs, + evaluate_attention_fusion( + queries, doc_objs, bayes, vector, + qrels, tokenizer, cutoffs, ) ) + # ----------------------------------------------------------------------- + # Ranking results table + # ----------------------------------------------------------------------- + print("=== Ranking Metrics ===") table = format_table(results, cutoffs) print(table) + # ----------------------------------------------------------------------- + # Calibration diagnostics + # ----------------------------------------------------------------------- + print("\n=== Calibration Metrics ===") + calibration_results = [] + calibration_results.append( + collect_calibration_data( + queries, doc_objs, "bayesian", + lambda terms, doc: bayes.score(terms, doc), + qrels, tokenizer, + ) + ) + + # Calibration for fitted Bayesian (if it ran) + fitted_row = next((r for r in results if r.get("scorer") == "bayesian_fitted"), None) + if fitted_row and "ece" in fitted_row: + calibration_results.append({ + "scorer": "bayesian_fitted", + "n_samples": fitted_row.get("n_samples", ""), + 
"ece": fitted_row["ece"], + "brier": fitted_row["brier"], + }) + + cal_table = format_calibration_table(calibration_results) + print(cal_table) + + # Reliability diagram for the Bayesian scorer + cal_data = calibration_results[0] + if "n_samples" in cal_data and cal_data["n_samples"]: + all_probs: List[float] = [] + all_labels: List[float] = [] + for query in queries: + rel_map = qrels.get(query.query_id, {}) + if not rel_map: + continue + terms = query.terms or tokenizer.tokenize(query.text) + for doc in doc_objs: + did = doc.id + if did not in rel_map: + continue + prob = bayes.score(terms, doc) + if prob > 0.0: + all_probs.append(prob) + all_labels.append(1.0 if rel_map[did] > 0 else 0.0) + + if all_probs: + report = bb.calibration_report(all_probs, all_labels) + print(f"\n{report.summary()}") + + # ----------------------------------------------------------------------- + # Output + # ----------------------------------------------------------------------- if args.output_json: - payload = {"cutoffs": cutoffs, "results": results} + payload = { + "cutoffs": cutoffs, + "results": results, + "calibration": calibration_results, + } args.output_json.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") diff --git a/src/attention_weights.rs b/src/attention_weights.rs new file mode 100644 index 0000000..035a9fe --- /dev/null +++ b/src/attention_weights.rs @@ -0,0 +1,428 @@ +use crate::fusion::{log_odds_conjunction, Gating}; +use crate::math_utils::{logit, min_max_normalize, safe_prob, sigmoid, softmax_rows}; + +/// Query-dependent signal weighting via attention (Paper 2, Section 8). +/// +/// Computes per-signal softmax attention weights from query features: +/// w_i(q) = softmax(W @ features + b)[i], then combines probability +/// signals via weighted log-odds conjunction. 
+pub struct AttentionLogOddsWeights { + n_signals: usize, + n_query_features: usize, + alpha: f64, + normalize: bool, + // W: (n_signals, n_query_features) stored row-major + w_matrix: Vec, + // b: (n_signals,) + bias: Vec, + // Online learning state + n_updates: usize, + grad_w_ema: Vec, + grad_b_ema: Vec, + // Polyak averaging + w_avg: Vec, + b_avg: Vec, +} + +impl AttentionLogOddsWeights { + /// Create new attention weights with Xavier initialization. + /// + /// Uses a fixed seed (0) for reproducible initialization. + pub fn new( + n_signals: usize, + n_query_features: usize, + alpha: f64, + normalize: bool, + ) -> Self { + assert!(n_signals >= 1, "n_signals must be >= 1, got {}", n_signals); + assert!( + n_query_features >= 1, + "n_query_features must be >= 1, got {}", + n_query_features + ); + + // Xavier-style initialization using a simple PRNG (same seed as Python's np.random.default_rng(0)) + let scale = 1.0 / (n_query_features as f64).sqrt(); + let total = n_signals * n_query_features; + let w_matrix = simple_normal_init(total, scale, 0); + + Self { + n_signals, + n_query_features, + alpha, + normalize, + w_matrix: w_matrix.clone(), + bias: vec![0.0; n_signals], + n_updates: 0, + grad_w_ema: vec![0.0; total], + grad_b_ema: vec![0.0; n_signals], + w_avg: w_matrix, + b_avg: vec![0.0; n_signals], + } + } + + pub fn n_signals(&self) -> usize { + self.n_signals + } + + pub fn n_query_features(&self) -> usize { + self.n_query_features + } + + pub fn alpha(&self) -> f64 { + self.alpha + } + + pub fn normalize(&self) -> bool { + self.normalize + } + + /// Weight matrix W of shape (n_signals, n_query_features). + pub fn weights_matrix(&self) -> Vec { + self.w_matrix.clone() + } + + /// Compute softmax attention weights from query features. 
+ /// + /// query_features: flat array of shape (m * n_query_features) + /// Returns flat array of shape (m * n_signals) + fn compute_weights(&self, query_features: &[f64], m: usize, use_averaged: bool) -> Vec { + let n = self.n_signals; + let nqf = self.n_query_features; + let w = if use_averaged { &self.w_avg } else { &self.w_matrix }; + let b = if use_averaged { &self.b_avg } else { &self.bias }; + + // z = query_features @ W^T + b + // query_features: (m, nqf), W: (n, nqf), result: (m, n) + let mut z = vec![0.0; m * n]; + for row in 0..m { + for col in 0..n { + let mut val = b[col]; + for k in 0..nqf { + val += query_features[row * nqf + k] * w[col * nqf + k]; + } + z[row * n + col] = val; + } + } + + softmax_rows(&z, n) + } + + /// Per-column min-max normalization on logit array (m rows, n_signals cols). + fn normalize_logits_columns(&self, x: &mut [f64], m: usize) { + let n = self.n_signals; + for col in 0..n { + let column: Vec = (0..m).map(|row| x[row * n + col]).collect(); + let normalized = min_max_normalize(&column); + for row in 0..m { + x[row * n + col] = normalized[row]; + } + } + } + + /// Combine probability signals via query-dependent weighted log-odds. + /// + /// probs: flat array of shape (m * n_signals) for m candidates + /// query_features: flat array of shape (m_q * n_query_features) + /// If m_q < m, the last query feature row is broadcast. 
+ pub fn combine( + &self, + probs: &[f64], + m: usize, + query_features: &[f64], + m_q: usize, + use_averaged: bool, + ) -> Vec { + let n = self.n_signals; + let weights = self.compute_weights(query_features, m_q, use_averaged); + + if m == 1 && !self.normalize { + // Single sample: just use weighted log-odds conjunction + let w_flat: Vec = (0..n).map(|j| weights[j]).collect(); + let row_probs: Vec = (0..n).map(|j| probs[j]).collect(); + return vec![log_odds_conjunction(&row_probs, Some(self.alpha), Some(&w_flat), Gating::NoGating)]; + } + + if self.normalize { + let scale = (n as f64).powf(self.alpha); + let mut x: Vec = probs.iter().map(|&p| logit(safe_prob(p))).collect(); + self.normalize_logits_columns(&mut x, m); + + let mut results = vec![0.0; m]; + for i in 0..m { + let wi_row = (i).min(m_q - 1); + let mut l_weighted = 0.0; + for j in 0..n { + l_weighted += weights[wi_row * n + j] * x[i * n + j]; + } + results[i] = sigmoid(scale * l_weighted); + } + return results; + } + + // Batched: each row has its own query-dependent weights + let mut results = vec![0.0; m]; + for i in 0..m { + let wi_row = (i).min(m_q - 1); + let w_slice: Vec = (0..n).map(|j| weights[wi_row * n + j]).collect(); + let row_probs: Vec = (0..n).map(|j| probs[i * n + j]).collect(); + results[i] = log_odds_conjunction(&row_probs, Some(self.alpha), Some(&w_slice), Gating::NoGating); + } + results + } + + /// Batch gradient descent on BCE loss to learn W and b. 
+ pub fn fit( + &mut self, + probs: &[f64], + labels: &[f64], + query_features: &[f64], + m: usize, + query_ids: Option<&[usize]>, + learning_rate: f64, + max_iterations: usize, + tolerance: f64, + ) { + let n = self.n_signals; + let nqf = self.n_query_features; + let scale = (n as f64).powf(self.alpha); + + // Compute logits of input signals + let mut x: Vec = probs.iter().map(|&p| logit(safe_prob(p))).collect(); + + if self.normalize { + if let Some(qids) = query_ids { + // Per-query group normalization + let mut unique_ids: Vec = qids.to_vec(); + unique_ids.sort_unstable(); + unique_ids.dedup(); + for &qid in &unique_ids { + let indices: Vec = (0..m).filter(|&i| qids[i] == qid).collect(); + let group_m = indices.len(); + for col in 0..n { + let column: Vec = indices.iter().map(|&i| x[i * n + col]).collect(); + let normalized = min_max_normalize(&column); + for (idx, &i) in indices.iter().enumerate() { + x[i * n + col] = normalized[idx]; + } + } + let _ = group_m; + } + } else { + self.normalize_logits_columns(&mut x, m); + } + } + + for _ in 0..max_iterations { + // Compute per-sample attention weights: z = qf @ W^T + b + let mut z = vec![0.0; m * n]; + for row in 0..m { + for col in 0..n { + let mut val = self.bias[col]; + for k in 0..nqf { + val += query_features[row * nqf + k] * self.w_matrix[col * nqf + k]; + } + z[row * n + col] = val; + } + } + let w = softmax_rows(&z, n); + + // Compute predictions and gradients + let mut grad_w = vec![0.0; n * nqf]; + let mut grad_b = vec![0.0; n]; + + for i in 0..m { + let x_bar_w: f64 = (0..n).map(|j| w[i * n + j] * x[i * n + j]).sum(); + let p = sigmoid(scale * x_bar_w); + let error = p - labels[i]; + + // grad_z_j = scale * error * w_j * (x_j - x_bar_w) + for j in 0..n { + let gz = scale * error * w[i * n + j] * (x[i * n + j] - x_bar_w); + // dL/dW_jk = gz * qf_k + for k in 0..nqf { + grad_w[j * nqf + k] += gz * query_features[i * nqf + k]; + } + grad_b[j] += gz; + } + } + + // Average over samples + let m_f = m as 
f64; + let old_w = self.w_matrix.clone(); + let old_b = self.bias.clone(); + + for idx in 0..grad_w.len() { + grad_w[idx] /= m_f; + self.w_matrix[idx] -= learning_rate * grad_w[idx]; + } + for j in 0..n { + grad_b[j] /= m_f; + self.bias[j] -= learning_rate * grad_b[j]; + } + + // Check convergence + let max_change_w = old_w + .iter() + .zip(self.w_matrix.iter()) + .map(|(&a, &b)| (a - b).abs()) + .fold(0.0_f64, f64::max); + let max_change_b = old_b + .iter() + .zip(self.bias.iter()) + .map(|(&a, &b)| (a - b).abs()) + .fold(0.0_f64, f64::max); + + if max_change_w.max(max_change_b) < tolerance { + break; + } + } + + // Reset online state + self.n_updates = 0; + self.grad_w_ema = vec![0.0; n * nqf]; + self.grad_b_ema = vec![0.0; n]; + self.w_avg = self.w_matrix.clone(); + self.b_avg = self.bias.clone(); + } + + /// Online SGD update from a single observation or mini-batch. + pub fn update( + &mut self, + probs: &[f64], + labels: &[f64], + query_features: &[f64], + m: usize, + learning_rate: f64, + momentum: f64, + decay_tau: f64, + max_grad_norm: f64, + avg_decay: f64, + ) { + let n = self.n_signals; + let nqf = self.n_query_features; + let scale = (n as f64).powf(self.alpha); + + let mut x: Vec = probs.iter().map(|&p| logit(safe_prob(p))).collect(); + + if self.normalize && m > 1 { + self.normalize_logits_columns(&mut x, m); + } + + // Compute attention weights + let mut z = vec![0.0; m * n]; + for row in 0..m { + for col in 0..n { + let mut val = self.bias[col]; + for k in 0..nqf { + val += query_features[row * nqf + k] * self.w_matrix[col * nqf + k]; + } + z[row * n + col] = val; + } + } + let w = softmax_rows(&z, n); + + // Compute gradients + let mut grad_w = vec![0.0; n * nqf]; + let mut grad_b = vec![0.0; n]; + + for i in 0..m { + let x_bar_w: f64 = (0..n).map(|j| w[i * n + j] * x[i * n + j]).sum(); + let p = sigmoid(scale * x_bar_w); + let error = p - labels[i]; + + for j in 0..n { + let gz = scale * error * w[i * n + j] * (x[i * n + j] - x_bar_w); + for k in 
0..nqf { + grad_w[j * nqf + k] += gz * query_features[i * nqf + k]; + } + grad_b[j] += gz; + } + } + + let m_f = m as f64; + for idx in 0..grad_w.len() { + grad_w[idx] /= m_f; + } + for j in 0..n { + grad_b[j] /= m_f; + } + + // EMA smoothing + for idx in 0..grad_w.len() { + self.grad_w_ema[idx] = + momentum * self.grad_w_ema[idx] + (1.0 - momentum) * grad_w[idx]; + } + for j in 0..n { + self.grad_b_ema[j] = + momentum * self.grad_b_ema[j] + (1.0 - momentum) * grad_b[j]; + } + + // Bias correction + self.n_updates += 1; + let correction = 1.0 - momentum.powi(self.n_updates as i32); + let mut corrected_w: Vec = self.grad_w_ema.iter().map(|&g| g / correction).collect(); + let mut corrected_b: Vec = self.grad_b_ema.iter().map(|&g| g / correction).collect(); + + // L2 gradient clipping (joint norm) + let grad_norm = (corrected_w.iter().map(|&g| g * g).sum::() + + corrected_b.iter().map(|&g| g * g).sum::()) + .sqrt(); + if grad_norm > max_grad_norm { + let clip_scale = max_grad_norm / grad_norm; + for g in corrected_w.iter_mut() { + *g *= clip_scale; + } + for g in corrected_b.iter_mut() { + *g *= clip_scale; + } + } + + // Learning rate decay + let effective_lr = learning_rate / (1.0 + self.n_updates as f64 / decay_tau); + + for idx in 0..self.w_matrix.len() { + self.w_matrix[idx] -= effective_lr * corrected_w[idx]; + } + for j in 0..n { + self.bias[j] -= effective_lr * corrected_b[j]; + } + + // Polyak averaging + for idx in 0..self.w_matrix.len() { + self.w_avg[idx] = + avg_decay * self.w_avg[idx] + (1.0 - avg_decay) * self.w_matrix[idx]; + } + for j in 0..n { + self.b_avg[j] = avg_decay * self.b_avg[j] + (1.0 - avg_decay) * self.bias[j]; + } + } +} + +/// Simple PRNG-based normal initialization for reproducibility. +/// +/// Uses a linear congruential generator + Box-Muller transform. 
fn simple_normal_init(n: usize, scale: f64, seed: u64) -> Vec<f64> {
    // PCG-style LCG constants (Knuth / Numerical Recipes 64-bit multiplier).
    const LCG_MUL: u64 = 6364136223846793005;
    const LCG_INC: u64 = 1442695040888963407;

    // Offset the seed so that seed == u64::MAX - 1 etc. still cycle normally.
    let mut state = seed.wrapping_add(1);

    // Draw one uniform in [0, 1) using the top 53 bits of the LCG state.
    let mut uniform = || {
        state = state.wrapping_mul(LCG_MUL).wrapping_add(LCG_INC);
        (state >> 11) as f64 / (1u64 << 53) as f64
    };

    // Box-Muller produces samples two at a time; generate pairs until we
    // have at least n, then trim the possible extra one.
    let mut samples = Vec::with_capacity(n + 1);
    while samples.len() < n {
        let u1 = uniform().max(1e-15); // keep ln() finite
        let u2 = uniform();

        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * std::f64::consts::PI * u2;
        samples.push(radius * angle.cos() * scale);
        samples.push(radius * angle.sin() * scale);
    }

    samples.truncate(n);
    samples
}
    /// Two-step Bayesian posterior update (Remark 4.4.5).
    ///
    /// Step 1: standard Bayes rule combining the score likelihood with the
    /// document prior (odds form).
    /// Step 2 (only if `base_rate` is set): a second Bayes update using the
    /// corpus-level base rate as prior, pulling the posterior toward it.
    pub fn posterior(&self, score: f64, prior: f64) -> f64 {
        // Clamp both inputs away from {0, 1} so the ratio below never
        // divides by zero.
        let lik = safe_prob(self.likelihood(score));
        let prior = safe_prob(prior);

        // Step 1: standard Bayes update
        let numerator = lik * prior;
        let denominator = numerator + (1.0 - lik) * (1.0 - prior);
        let p1 = numerator / denominator;

        // Step 2: base rate adjustment
        match self.base_rate {
            Some(br) => {
                let num2 = p1 * br;
                let den2 = num2 + (1.0 - p1) * (1.0 - br);
                num2 / den2
            }
            None => p1,
        }
    }

    /// Combine per-term posteriors with a probabilistic OR (noisy-OR).
    pub fn score(&self, query_terms: &[String], doc: &Document) -> f64 {
        // Terms absent from the document score 0 and are dropped so they do
        // not dilute the OR-combination.
        let posteriors: Vec<f64> = query_terms
            .iter()
            .map(|term| self.score_term(term, doc))
            .filter(|&p| p > 0.0)
            .collect();

        if posteriors.is_empty() {
            return 0.0;
        }

        fusion::prob_or(&posteriors)
    }
}

use std::collections::HashMap;

use crate::fusion::{cosine_to_probability, prob_not};
use crate::math_utils::{logit, safe_prob, sigmoid};
use crate::probability::BayesianProbabilityTransform;

/// Trace of a single BM25 signal through the full probability pipeline.
#[derive(Clone, Debug)]
pub struct BM25SignalTrace {
    // Inputs
    pub raw_score: f64,      // raw BM25 score
    pub tf: f64,             // term frequency used for the TF prior
    pub doc_len_ratio: f64,  // doc length / avg doc length
    // Intermediate probabilities
    pub likelihood: f64,
    pub tf_prior: f64,
    pub norm_prior: f64,
    pub composite_prior: f64,
    // Logit-space views of the same quantities (for debugging fusion)
    pub logit_likelihood: f64,
    pub logit_prior: f64,
    pub logit_base_rate: Option<f64>,
    // Output and the transform parameters that produced it
    pub posterior: f64,
    pub alpha: f64,
    pub beta: f64,
    pub base_rate: Option<f64>,
}

/// Trace of a cosine similarity through probability conversion.
#[derive(Clone, Debug)]
pub struct VectorSignalTrace {
    pub cosine_score: f64,      // raw cosine similarity in [-1, 1]
    pub probability: f64,       // mapped into (0, 1)
    pub logit_probability: f64, // logit of the mapped probability
}

/// Trace of a probabilistic NOT (complement) operation.
#[derive(Clone, Debug)]
pub struct NotTrace {
    pub input_probability: f64,
    pub input_name: String,
    pub complement: f64,
    pub logit_input: f64,
    pub logit_complement: f64,
}

/// Trace of the combination step for multiple probability signals.
///
/// Only the intermediates relevant to the chosen `method` are `Some`;
/// the rest stay `None`.
#[derive(Clone, Debug)]
pub struct FusionTrace {
    pub signal_probabilities: Vec<f64>,
    pub signal_names: Vec<String>,
    pub method: String,
    // Log-odds intermediates
    pub logits: Option<Vec<f64>>,
    pub mean_logit: Option<f64>,
    pub alpha: Option<f64>,
    pub n_alpha_scale: Option<f64>,
    pub scaled_logit: Option<f64>,
    pub weights: Option<Vec<f64>>,
    // prob_and intermediates
    pub log_probs: Option<Vec<f64>>,
    pub log_prob_sum: Option<f64>,
    // prob_or/prob_not intermediates
    pub complements: Option<Vec<f64>>,
    pub log_complements: Option<Vec<f64>>,
    pub log_complement_sum: Option<f64>,
    // Output
    pub fused_probability: f64,
}

/// Signal type enum for document traces.
#[derive(Clone, Debug)]
pub enum SignalTrace {
    BM25(BM25SignalTrace),
    Vector(VectorSignalTrace),
}

/// Complete trace for one document across all signals and fusion.
#[derive(Clone, Debug)]
pub struct DocumentTrace {
    pub doc_id: Option<String>,
    pub signals: Vec<(String, SignalTrace)>, // (name, trace) in insertion order
    pub fusion: FusionTrace,
    pub final_probability: f64,
}

/// Comparison of two document traces explaining rank differences.
#[derive(Clone, Debug)]
pub struct ComparisonResult {
    pub doc_a: DocumentTrace,
    pub doc_b: DocumentTrace,
    // Per-signal probability deltas (doc_a minus doc_b), in signal order.
    pub signal_deltas: Vec<(String, f64)>,
    // Name of the signal with the largest absolute delta.
    pub dominant_signal: String,
    // A non-dominant signal that pointed the opposite way from the fused
    // result, if any.
    pub crossover_stage: Option<String>,
}

/// Traces intermediate values through the Bayesian BM25 fusion pipeline.
pub struct FusionDebugger {
    transform: BayesianProbabilityTransform,
}

impl FusionDebugger {
    pub fn new(transform: BayesianProbabilityTransform) -> Self {
        Self { transform }
    }

    /// Read-only access to the underlying probability transform.
    pub fn transform(&self) -> &BayesianProbabilityTransform {
        &self.transform
    }

    /// Trace a single BM25 score through the full probability pipeline.
    ///
    /// Records every intermediate: likelihood, the TF and length-norm
    /// priors, the composite prior, the posterior, and logit-space views.
    pub fn trace_bm25(
        &self,
        score: f64,
        tf: f64,
        doc_len_ratio: f64,
    ) -> BM25SignalTrace {
        let t = &self.transform;

        let likelihood_val = t.likelihood(score);
        let tf_prior_val = BayesianProbabilityTransform::tf_prior(tf);
        let norm_prior_val = BayesianProbabilityTransform::norm_prior(doc_len_ratio);
        let composite_prior_val = BayesianProbabilityTransform::composite_prior(tf, doc_len_ratio);
        let posterior_val = BayesianProbabilityTransform::posterior(
            likelihood_val,
            composite_prior_val,
            t.base_rate,
        );

        // Logit-space views for fusion debugging.
        let logit_likelihood_val = logit(likelihood_val);
        let logit_prior_val = logit(composite_prior_val);
        let logit_base_rate_val = t.base_rate.map(|br| logit(safe_prob(br)));

        BM25SignalTrace {
            raw_score: score,
            tf,
            doc_len_ratio,
            likelihood: likelihood_val,
            tf_prior: tf_prior_val,
            norm_prior: norm_prior_val,
            composite_prior: composite_prior_val,
            logit_likelihood: logit_likelihood_val,
            logit_prior: logit_prior_val,
            logit_base_rate: logit_base_rate_val,
            posterior: posterior_val,
            alpha: t.alpha,
            beta: t.beta,
            base_rate: t.base_rate,
        }
    }

    /// Trace a cosine similarity through probability conversion.
    pub fn trace_vector(&self, cosine_score: f64) -> VectorSignalTrace {
        // Map [-1, 1] cosine into (0, 1), then record its logit.
        let prob_val = cosine_to_probability(cosine_score);
        let logit_val = logit(prob_val);

        VectorSignalTrace {
            cosine_score,
            probability: prob_val,
            logit_probability: logit_val,
        }
    }

    /// Trace a probabilistic NOT (complement) operation.
    pub fn trace_not(&self, probability: f64, name: &str) -> NotTrace {
        let complement = prob_not(probability);
        let logit_in = logit(safe_prob(probability));
        let logit_out = logit(safe_prob(complement));

        NotTrace {
            input_probability: probability,
            input_name: name.to_string(),
            complement,
            logit_input: logit_in,
            logit_complement: logit_out,
        }
    }

    /// Trace the fusion of multiple probability signals.
    ///
    /// `method` selects the combinator: "log_odds", "prob_and", "prob_or",
    /// or "prob_not". `alpha`/`weights` only apply to "log_odds".
    /// Panics on an unknown method name.
    pub fn trace_fusion(
        &self,
        probabilities: &[f64],
        names: Option<&[String]>,
        method: &str,
        alpha: Option<f64>,
        weights: Option<&[f64]>,
    ) -> FusionTrace {
        let n = probabilities.len();
        // Default names "signal_0", "signal_1", ... when none are given.
        let signal_names: Vec<String> = match names {
            Some(ns) => ns.to_vec(),
            None => (0..n).map(|i| format!("signal_{}", i)).collect(),
        };

        // Clamp away from {0, 1} once, up front.
        let probs: Vec<f64> = probabilities.iter().map(|&p| safe_prob(p)).collect();

        match method {
            "log_odds" => self.trace_log_odds(&probs, &signal_names, alpha, weights),
            "prob_and" => self.trace_prob_and(&probs, &signal_names),
            "prob_or" => self.trace_prob_or(&probs, &signal_names),
            "prob_not" => self.trace_prob_not_fusion(&probs, &signal_names),
            _ => panic!("method must be 'log_odds', 'prob_and', 'prob_or', or 'prob_not', got '{}'", method),
        }
    }

    /// Log-odds fusion trace. Weighted path: sigmoid(n^alpha * sum(w_i * l_i))
    /// with default alpha = 0.0; unweighted path: sigmoid(mean(l_i) * n^alpha)
    /// with default alpha = 0.5 (mirrors `fusion::log_odds_conjunction`).
    fn trace_log_odds(
        &self,
        probs: &[f64],
        names: &[String],
        alpha: Option<f64>,
        weights: Option<&[f64]>,
    ) -> FusionTrace {
        let n = probs.len();
        let logits_arr: Vec<f64> = probs.iter().map(|&p| logit(p)).collect();

        if let Some(w) = weights {
            let effective_alpha = alpha.unwrap_or(0.0);
            let n_alpha_scale = (n as f64).powf(effective_alpha);
            let weighted_logit: f64 = w.iter().zip(logits_arr.iter()).map(|(&wi, &li)| wi * li).sum();
            let scaled = n_alpha_scale * weighted_logit;
            let fused = sigmoid(scaled);

            return FusionTrace {
                signal_probabilities: probs.to_vec(),
                signal_names: names.to_vec(),
                method: "log_odds".to_string(),
                logits: Some(logits_arr),
                // mean_logit carries the weighted sum in the weighted path.
                mean_logit: Some(weighted_logit),
                alpha: Some(effective_alpha),
                n_alpha_scale: Some(n_alpha_scale),
                scaled_logit: Some(scaled),
                weights: Some(w.to_vec()),
                log_probs: None,
                log_prob_sum: None,
                complements: None,
                log_complements: None,
                log_complement_sum: None,
                fused_probability: fused,
            };
        }

        let effective_alpha = alpha.unwrap_or(0.5);
        let mean_logit_val: f64 = logits_arr.iter().sum::<f64>() / n as f64;
        let n_alpha_scale = (n as f64).powf(effective_alpha);
        let scaled = mean_logit_val * n_alpha_scale;
        let fused = sigmoid(scaled);

        FusionTrace {
            signal_probabilities: probs.to_vec(),
            signal_names: names.to_vec(),
            method: "log_odds".to_string(),
            logits: Some(logits_arr),
            mean_logit: Some(mean_logit_val),
            alpha: Some(effective_alpha),
            n_alpha_scale: Some(n_alpha_scale),
            scaled_logit: Some(scaled),
            weights: None,
            log_probs: None,
            log_prob_sum: None,
            complements: None,
            log_complements: None,
            log_complement_sum: None,
            fused_probability: fused,
        }
    }

    /// AND trace: product of probabilities, computed as exp(sum(ln p_i)).
    fn trace_prob_and(&self, probs: &[f64], names: &[String]) -> FusionTrace {
        let log_probs: Vec<f64> = probs.iter().map(|&p| p.ln()).collect();
        let log_sum: f64 = log_probs.iter().sum();
        let fused = log_sum.exp();

        FusionTrace {
            signal_probabilities: probs.to_vec(),
            signal_names: names.to_vec(),
            method: "prob_and".to_string(),
            logits: None,
            mean_logit: None,
            alpha: None,
            n_alpha_scale: None,
            scaled_logit: None,
            weights: None,
            log_probs: Some(log_probs),
            log_prob_sum: Some(log_sum),
            complements: None,
            log_complements: None,
            log_complement_sum: None,
            fused_probability: fused,
        }
    }

    /// OR trace: 1 - product(1 - p_i), computed via log-complements.
    fn trace_prob_or(&self, probs: &[f64], names: &[String]) -> FusionTrace {
        let comps: Vec<f64> = probs.iter().map(|&p| 1.0 - p).collect();
        let log_comps: Vec<f64> = comps.iter().map(|&c| c.ln()).collect();
        let log_sum: f64 = log_comps.iter().sum();
        let fused = 1.0 - log_sum.exp();

        FusionTrace {
            signal_probabilities: probs.to_vec(),
            signal_names: names.to_vec(),
            method: "prob_or".to_string(),
            logits: None,
            mean_logit: None,
            alpha: None,
            n_alpha_scale: None,
            scaled_logit: None,
            weights: None,
            log_probs: None,
            log_prob_sum: None,
            complements: Some(comps),
            log_complements: Some(log_comps),
            log_complement_sum: Some(log_sum),
            fused_probability: fused,
        }
    }

    /// NOT-fusion trace: product of complements (probability that no signal
    /// fires) — same intermediates as OR but without the final 1 - exp().
    fn trace_prob_not_fusion(&self, probs: &[f64], names: &[String]) -> FusionTrace {
        let comps: Vec<f64> = probs.iter().map(|&p| 1.0 - p).collect();
        let log_comps: Vec<f64> = comps.iter().map(|&c| c.ln()).collect();
        let log_sum: f64 = log_comps.iter().sum();
        let fused = log_sum.exp();

        FusionTrace {
            signal_probabilities: probs.to_vec(),
            signal_names: names.to_vec(),
            method: "prob_not".to_string(),
            logits: None,
            mean_logit: None,
            alpha: None,
            n_alpha_scale: None,
            scaled_logit: None,
            weights: None,
            log_probs: None,
            log_prob_sum: None,
            complements: Some(comps),
            log_complements: Some(log_comps),
            log_complement_sum: Some(log_sum),
            fused_probability: fused,
        }
    }

    /// Full pipeline trace for one document (convenience method).
    /// Traces any provided signals (BM25 and/or vector) and their fusion.
    ///
    /// `tf` and `doc_len_ratio` are required whenever `bm25_score` is given
    /// (panics otherwise); at least one of `bm25_score` / `cosine_score`
    /// must be provided (panics otherwise).
    pub fn trace_document(
        &self,
        bm25_score: Option<f64>,
        tf: Option<f64>,
        doc_len_ratio: Option<f64>,
        cosine_score: Option<f64>,
        method: &str,
        alpha: Option<f64>,
        weights: Option<&[f64]>,
        doc_id: Option<&str>,
    ) -> DocumentTrace {
        let mut signals: Vec<(String, SignalTrace)> = Vec::new();
        let mut probs: Vec<f64> = Vec::new();
        let mut names: Vec<String> = Vec::new();

        if let Some(bm25) = bm25_score {
            let tf_val = tf.expect("tf is required when bm25_score is provided");
            let dlr_val = doc_len_ratio.expect("doc_len_ratio is required when bm25_score is provided");
            let trace = self.trace_bm25(bm25, tf_val, dlr_val);
            // The BM25 signal contributes its posterior to the fusion.
            probs.push(trace.posterior);
            names.push("BM25".to_string());
            signals.push(("BM25".to_string(), SignalTrace::BM25(trace)));
        }

        if let Some(cos) = cosine_score {
            let trace = self.trace_vector(cos);
            probs.push(trace.probability);
            names.push("Vector".to_string());
            signals.push(("Vector".to_string(), SignalTrace::Vector(trace)));
        }

        assert!(!probs.is_empty(), "At least one of bm25_score or cosine_score must be provided");

        let fusion_trace = self.trace_fusion(
            &probs,
            Some(&names),
            method,
            alpha,
            weights,
        );

        DocumentTrace {
            doc_id: doc_id.map(|s| s.to_string()),
            signals,
            fusion: fusion_trace.clone(),
            final_probability: fusion_trace.fused_probability,
        }
    }

    /// Compare two document traces to explain rank differences.
    /// Computes per-signal deltas (a minus b), the dominant signal, and an
    /// optional "crossover" signal that pointed against the fused outcome.
    pub fn compare(
        &self,
        trace_a: &DocumentTrace,
        trace_b: &DocumentTrace,
    ) -> ComparisonResult {
        // Collect all unique signal names preserving order (a's first).
        let mut all_names: Vec<String> = Vec::new();
        let mut seen = HashMap::new();
        for (name, _) in &trace_a.signals {
            if seen.insert(name.clone(), ()).is_none() {
                all_names.push(name.clone());
            }
        }
        for (name, _) in &trace_b.signals {
            if seen.insert(name.clone(), ()).is_none() {
                all_names.push(name.clone());
            }
        }

        // Delta per signal; a missing signal contributes the neutral 0.5
        // (see `signal_probability`).
        let mut signal_deltas: Vec<(String, f64)> = Vec::new();
        for name in &all_names {
            let prob_a = signal_probability(trace_a, name);
            let prob_b = signal_probability(trace_b, name);
            signal_deltas.push((name.clone(), prob_a - prob_b));
        }

        // Dominant signal: largest absolute delta.
        // NOTE(review): partial_cmp().unwrap() panics if a delta is NaN —
        // confirm upstream probabilities are always finite.
        let dominant = signal_deltas
            .iter()
            .max_by(|a, b| a.1.abs().partial_cmp(&b.1.abs()).unwrap())
            .map(|(name, _)| name.clone())
            .unwrap_or_default();

        // Crossover detection: first non-dominant signal whose delta has the
        // opposite sign of the fused delta.
        let fused_delta = trace_a.final_probability - trace_b.final_probability;
        let mut crossover_stage: Option<String> = None;
        for (name, delta) in &signal_deltas {
            if name == &dominant {
                continue;
            }
            if fused_delta != 0.0
                && *delta != 0.0
                && ((fused_delta > 0.0 && *delta < 0.0) || (fused_delta < 0.0 && *delta > 0.0))
            {
                crossover_stage = Some(name.clone());
                break;
            }
        }

        ComparisonResult {
            doc_a: trace_a.clone(),
            doc_b: trace_b.clone(),
            signal_deltas,
            dominant_signal: dominant,
            crossover_stage,
        }
    }

    /// Format a document trace as human-readable text.
    /// `verbose` additionally prints logit-space intermediates for each
    /// signal and the fusion step.
    pub fn format_trace(&self, trace: &DocumentTrace, verbose: bool) -> String {
        let mut lines: Vec<String> = Vec::new();
        let doc_label = trace.doc_id.as_deref().unwrap_or("unknown");
        lines.push(format!("Document: {}", doc_label));

        // One indented section per signal.
        for (name, sig) in &trace.signals {
            match sig {
                SignalTrace::BM25(s) => {
                    lines.push(format!(
                        "  [{}] raw={:.2} -> likelihood={:.3} (alpha={:.2}, beta={:.2})",
                        name, s.raw_score, s.likelihood, s.alpha, s.beta
                    ));
                    lines.push(format!("    tf={:.0} -> tf_prior={:.3}", s.tf, s.tf_prior));
                    lines.push(format!(
                        "    dl_ratio={:.2} -> norm_prior={:.3}",
                        s.doc_len_ratio, s.norm_prior
                    ));
                    lines.push(format!("    composite_prior={:.3}", s.composite_prior));
                    if let Some(br) = s.base_rate {
                        // Recompute the step-1 posterior (without base rate)
                        // so both values can be shown side by side.
                        let posterior_no_br = BayesianProbabilityTransform::posterior(
                            s.likelihood,
                            s.composite_prior,
                            None,
                        );
                        lines.push(format!("    posterior={:.3}", posterior_no_br));
                        lines.push(format!(
                            "    with base_rate={:.3}: posterior={:.3}",
                            br, s.posterior
                        ));
                    } else {
                        lines.push(format!("    posterior={:.3}", s.posterior));
                    }
                    if verbose {
                        lines.push(format!(
                            "    logit(posterior)={:.3}",
                            logit(safe_prob(s.posterior))
                        ));
                    }
                    lines.push(String::new());
                }
                SignalTrace::Vector(s) => {
                    lines.push(format!(
                        "  [{}] cosine={:.3} -> prob={:.3}",
                        name, s.cosine_score, s.probability
                    ));
                    if verbose {
                        lines.push(format!("    logit(prob)={:.3}", s.logit_probability));
                    }
                    lines.push(String::new());
                }
            }
        }

        // Fusion section.
        let f = &trace.fusion;
        let alpha_str = f.alpha.map_or(String::new(), |a| format!(", alpha={}", a));
        let n_str = format!(", n={}", f.signal_probabilities.len());
        lines.push(format!("  [Fusion] method={}{}{}", f.method, alpha_str, n_str));

        if verbose {
            // Only the intermediates the chosen method populated are Some.
            if let Some(ref logits) = f.logits {
                let s: Vec<String> = logits.iter().map(|v| format!("{:.3}", v)).collect();
                lines.push(format!("    logits=[{}]", s.join(", ")));
            }
            if let Some(ml) = f.mean_logit {
                lines.push(format!("    mean_logit={:.3}", ml));
            }
            if let (Some(nas), Some(sl)) = (f.n_alpha_scale, f.scaled_logit) {
                lines.push(format!("    n^alpha={:.3}, scaled={:.3}", nas, sl));
            }
            if let Some(ref w) = f.weights {
                let s: Vec<String> = w.iter().map(|v| format!("{:.3}", v)).collect();
                lines.push(format!("    weights=[{}]", s.join(", ")));
            }
            if let Some(ref lp) = f.log_probs {
                let s: Vec<String> = lp.iter().map(|v| format!("{:.3}", v)).collect();
                lines.push(format!("    ln(P)=[{}]", s.join(", ")));
                if let Some(lps) = f.log_prob_sum {
                    lines.push(format!("    sum(ln(P))={:.3}", lps));
                }
            }
            if let Some(ref c) = f.complements {
                let s: Vec<String> = c.iter().map(|v| format!("{:.3}", v)).collect();
                lines.push(format!("    1-P=[{}]", s.join(", ")));
            }
            if let Some(ref lc) = f.log_complements {
                let s: Vec<String> = lc.iter().map(|v| format!("{:.3}", v)).collect();
                lines.push(format!("    ln(1-P)=[{}]", s.join(", ")));
                if let Some(lcs) = f.log_complement_sum {
                    lines.push(format!("    sum(ln(1-P))={:.3}", lcs));
                }
            }
        }

        lines.push(format!("  -> final={:.3}", f.fused_probability));
        lines.join("\n")
    }

    /// Compact one-line summary of a document trace.
    pub fn format_summary(&self, trace: &DocumentTrace) -> String {
        let doc_label = trace.doc_id.as_deref().unwrap_or("unknown");
        // One "Name=prob" token per signal.
        let mut parts: Vec<String> = Vec::new();
        for (_, sig) in &trace.signals {
            match sig {
                SignalTrace::BM25(s) => parts.push(format!("BM25={:.3}", s.posterior)),
                SignalTrace::Vector(s) => parts.push(format!("Vec={:.3}", s.probability)),
            }
        }

        let f = &trace.fusion;
        let alpha_str = f.alpha.map_or(String::new(), |a| format!(", alpha={}", a));
        format!(
            "{}: {} -> Fused={:.3} ({}{})",
            doc_label,
            parts.join(" "),
            f.fused_probability,
            f.method,
            alpha_str
        )
    }

    /// Format a comparison result as human-readable text.
    /// Renders the per-signal delta table, the fused delta, the implied rank
    /// order, the dominant signal, and any crossover note.
    pub fn format_comparison(&self, comparison: &ComparisonResult) -> String {
        let a = &comparison.doc_a;
        let b = &comparison.doc_b;
        let a_label = a.doc_id.as_deref().unwrap_or("doc_a");
        let b_label = b.doc_id.as_deref().unwrap_or("doc_b");

        let mut lines: Vec<String> = Vec::new();
        lines.push(format!("Comparison: {} vs {}", a_label, b_label));

        // Table header.
        lines.push(format!(
            "  {:<12} {:>8} {:>8} {:>8}  dominant",
            "Signal", a_label, b_label, "delta"
        ));

        for (name, delta) in &comparison.signal_deltas {
            let prob_a = signal_probability(a, name);
            let prob_b = signal_probability(b, name);
            let dominant_marker = if name == &comparison.dominant_signal {
                "  <-- largest"
            } else {
                ""
            };
            lines.push(format!(
                "  {:<12} {:>8.3} {:>8.3} {:>+8.3}{}",
                name, prob_a, prob_b, delta, dominant_marker
            ));
        }

        // Fused row.
        let fused_delta = a.final_probability - b.final_probability;
        lines.push(format!(
            "  {:<12} {:>8.3} {:>8.3} {:>+8.3}",
            "Fused", a.final_probability, b.final_probability, fused_delta
        ));
        lines.push(String::new());

        // Rank-order statement derived from the fused delta sign.
        if fused_delta > 0.0 {
            lines.push(format!(
                "  Rank order: {} > {} (by {:+.3})",
                a_label, b_label, fused_delta
            ));
        } else if fused_delta < 0.0 {
            lines.push(format!(
                "  Rank order: {} > {} (by +{:.3})",
                b_label, a_label, fused_delta.abs()
            ));
        } else {
            lines.push("  Rank order: tied".to_string());
        }

        // Dominant-signal explanation.
        let dom = &comparison.dominant_signal;
        let dom_delta = comparison
            .signal_deltas
            .iter()
            .find(|(n, _)| n == dom)
            .map(|(_, d)| *d)
            .unwrap_or(0.0);
        let favored = if dom_delta >= 0.0 { a_label } else { b_label };
        lines.push(format!(
            "  Dominant signal: {} ({:+.3} in {}'s favor)",
            dom, dom_delta, favored
        ));

        // Crossover note: a signal that favored the losing document.
        if let Some(ref cross) = comparison.crossover_stage {
            let cross_delta = comparison
                .signal_deltas
                .iter()
                .find(|(n, _)| n == cross)
                .map(|(_, d)| *d)
                .unwrap_or(0.0);
            let cross_favored = if cross_delta >= 0.0 { a_label } else { b_label };
            lines.push(format!(
                "  Note: {} favored {}, but {} signal outweighed it",
                cross, cross_favored, dom
            ));
        }

        lines.join("\n")
    }
}

/// Extract the final probability from a signal within a document trace.
///
/// Returns the neutral 0.5 when the named signal is absent from the trace.
fn signal_probability(trace: &DocumentTrace, name: &str) -> f64 {
    for (n, sig) in &trace.signals {
        if n == name {
            return match sig {
                SignalTrace::BM25(s) => s.posterior,
                SignalTrace::Vector(s) => s.probability,
            };
        }
    }
    0.5 // neutral if signal missing
}
    /// Experiment 11: the optional corpus-level base-rate prior.
    ///
    /// Verifies on the first query, for every document, that:
    ///   - a low base rate (0.01) never increases the score vs. no base rate,
    ///   - base_rate = 0.5 is (approximately) a no-op,
    ///   - all scores stay inside [0, 1] (within EPSILON).
    fn exp11_base_rate_prior(&self) -> (bool, String) {
        let mut passed = true;
        let mut details: Vec<String> = Vec::new();

        let query = &self.queries[0];

        // Scorer without base_rate (default)
        let scorer_none = BayesianBM25Scorer::new(Rc::clone(&self.bm25), 1.0, 0.5, None);

        // Low base_rate should reduce posteriors
        let scorer_low = BayesianBM25Scorer::new(Rc::clone(&self.bm25), 1.0, 0.5, Some(0.01));

        // base_rate=0.5 should be neutral (no change from step 1)
        let scorer_neutral = BayesianBM25Scorer::new(Rc::clone(&self.bm25), 1.0, 0.5, Some(0.5));

        let mut low_reduces = true;
        let mut neutral_ok = true;
        let mut all_in_range = true;

        for doc in self.corpus.documents() {
            let score_none = scorer_none.score(&query.terms, doc);
            let score_low = scorer_low.score(&query.terms, doc);
            let score_neutral = scorer_neutral.score(&query.terms, doc);

            // Low base_rate should reduce posteriors vs None
            if score_none > EPSILON && score_low > score_none + EPSILON {
                low_reduces = false;
                details.push(format!(
                    "doc={}: low={:.6} > none={:.6}",
                    doc.id, score_low, score_none
                ));
            }

            // base_rate=0.5 should be approximately neutral
            if (score_neutral - score_none).abs() > 1e-4 {
                neutral_ok = false;
                details.push(format!(
                    "doc={}: neutral={:.6} != none={:.6}",
                    doc.id, score_neutral, score_none
                ));
            }

            // All results must be in (0, 1)
            for &s in &[score_none, score_low, score_neutral] {
                if s < -EPSILON || s > 1.0 + EPSILON {
                    all_in_range = false;
                }
            }
        }

        if !low_reduces {
            passed = false;
        }
        if !neutral_ok {
            passed = false;
        }
        if !all_in_range {
            passed = false;
        }

        let mut detail = format!(
            "low_reduces={}, neutral_ok={}, all_in_range={}",
            low_reduces, neutral_ok, all_in_range
        );
        if !details.is_empty() {
            // Only show the first few violations to keep the report short.
            let preview = details.into_iter().take(3).collect::<Vec<_>>().join("; ");
            detail.push_str(", violations: ");
            detail.push_str(&preview);
        }

        (passed, detail)
    }

    /// Experiment 12: algebraic properties of log-odds conjunction
    /// (agreement amplification, disagreement moderation, neutral identity,
    /// n- and alpha-scaling, uniform-weight equivalence).
    fn exp12_log_odds_conjunction(&self) -> (bool, String) {
        let mut passed = true;
        let mut violations: Vec<String> = Vec::new();

        // Agreement amplification: conj([0.9, 0.9]) > 0.9
        let result = fusion::log_odds_conjunction(&[0.9, 0.9], None, None, fusion::Gating::NoGating);
        if result <= 0.9 {
            passed = false;
            violations.push(format!("agreement: conj([0.9,0.9])={:.6} <= 0.9", result));
        }

        // Disagreement moderation: conj([0.9, 0.1]) should be near 0.5
        let result = fusion::log_odds_conjunction(&[0.9, 0.1], None, None, fusion::Gating::NoGating);
        if (result - 0.5).abs() > 0.15 {
            passed = false;
            violations.push(format!("disagreement: conj([0.9,0.1])={:.6} not near 0.5", result));
        }

        // Neutral identity: conj([0.5, 0.5]) = 0.5
        let result = fusion::log_odds_conjunction(&[0.5, 0.5], None, None, fusion::Gating::NoGating);
        if (result - 0.5).abs() > EPSILON {
            passed = false;
            violations.push(format!("neutral: conj([0.5,0.5])={:.6} != 0.5", result));
        }

        // More signals amplify: conj([0.8]*3) > conj([0.8]*2)
        let conj2 = fusion::log_odds_conjunction(&[0.8, 0.8], None, None, fusion::Gating::NoGating);
        let conj3 = fusion::log_odds_conjunction(&[0.8, 0.8, 0.8], None, None, fusion::Gating::NoGating);
        if conj3 <= conj2 {
            passed = false;
            violations.push(format!(
                "amplification: conj([0.8]*3)={:.6} <= conj([0.8]*2)={:.6}",
                conj3, conj2
            ));
        }

        // Alpha effect: higher alpha = stronger amplification
        let low_alpha = fusion::log_odds_conjunction(&[0.8, 0.8], Some(0.3), None, fusion::Gating::NoGating);
        let high_alpha = fusion::log_odds_conjunction(&[0.8, 0.8], Some(0.8), None, fusion::Gating::NoGating);
        if high_alpha <= low_alpha {
            passed = false;
            violations.push(format!(
                "alpha: high_alpha={:.6} <= low_alpha={:.6}",
                high_alpha, low_alpha
            ));
        }

        // Weighted with uniform weights + alpha=0 matches unweighted alpha=0
        let uniform_w = vec![0.5, 0.5];
        let probs = [0.7, 0.8];
        let weighted = fusion::log_odds_conjunction(&probs, Some(0.0), Some(&uniform_w), fusion::Gating::NoGating);
        let unweighted = fusion::log_odds_conjunction(&probs, Some(0.0), None, fusion::Gating::NoGating);
        if (weighted - unweighted).abs() > 1e-6 {
            passed = false;
            violations.push(format!(
                "uniform_weights: weighted={:.6} != unweighted={:.6}",
                weighted, unweighted
            ));
        }

        let detail = if violations.is_empty() {
            "all properties verified".to_string()
        } else {
            let preview = violations.into_iter().take(3).collect::<Vec<_>>().join("; ");
            format!("violations: {}", preview)
        };

        (passed, detail)
    }

    /// Experiment 13: basic fusion primitives (NOT involution and bounds,
    /// De Morgan's law, balanced fusion shape and weight effect,
    /// cosine-to-probability bounds and monotonicity).
    fn exp13_fusion_primitives(&self) -> (bool, String) {
        let mut passed = true;
        let mut violations: Vec<String> = Vec::new();

        // prob_not: involution (not(not(p)) = p)
        for &p in &[0.1, 0.3, 0.5, 0.7, 0.9] {
            let roundtrip = fusion::prob_not(fusion::prob_not(p));
            if (roundtrip - p).abs() > 1e-8 {
                passed = false;
                violations.push(format!("involution: not(not({:.1}))={:.6} != {:.1}", p, roundtrip, p));
            }
        }

        // prob_not: bounds
        let not_low = fusion::prob_not(0.01);
        let not_high = fusion::prob_not(0.99);
        if not_low < 0.0 || not_low > 1.0 || not_high < 0.0 || not_high > 1.0 {
            passed = false;
            violations.push("prob_not out of bounds".to_string());
        }

        // De Morgan's law: not(and(p1, p2)) = or(not(p1), not(p2))
        for &(p1, p2) in &[(0.3, 0.7), (0.1, 0.9), (0.5, 0.5)] {
            let lhs = fusion::prob_not(fusion::prob_and(&[p1, p2]));
            let rhs = fusion::prob_or(&[fusion::prob_not(p1), fusion::prob_not(p2)]);
            if (lhs - rhs).abs() > 1e-8 {
                passed = false;
                violations.push(format!(
                    "de_morgan: not(and({:.1},{:.1}))={:.6} != or(not,not)={:.6}",
                    p1, p2, lhs, rhs
                ));
            }
        }

        // balanced_log_odds_fusion: output dimension
        let sparse = vec![0.3, 0.6, 0.8];
        let dense = vec![0.1, 0.5, 0.9];
        let fused = fusion::balanced_log_odds_fusion(&sparse, &dense, 0.5);
        if fused.len() != sparse.len() {
            passed = false;
            violations.push(format!("fusion dim: {} != {}", fused.len(), sparse.len()));
        }

        // balanced_log_odds_fusion: weight effect (weight=1.0 should favor dense)
        let fused_dense = fusion::balanced_log_odds_fusion(&sparse, &dense, 1.0);
        let fused_sparse = fusion::balanced_log_odds_fusion(&sparse, &dense, 0.0);
        if fused_dense == fused_sparse {
            passed = false;
            violations.push("weight has no effect".to_string());
        }

        // cosine_to_probability: bounds
        for &s in &[-1.0, -0.5, 0.0, 0.5, 1.0] {
            let p = fusion::cosine_to_probability(s);
            if p <= 0.0 || p >= 1.0 {
                passed = false;
                violations.push(format!("cos_to_prob({:.1})={:.6} out of (0,1)", s, p));
            }
        }

        // cosine_to_probability: monotonicity
        let mut prev = fusion::cosine_to_probability(-1.0);
        for &s in &[-0.5, 0.0, 0.5, 1.0] {
            let curr = fusion::cosine_to_probability(s);
            if curr < prev - EPSILON {
                passed = false;
                violations.push(format!("cos_to_prob not monotonic at {:.1}", s));
            }
            prev = curr;
        }

        let detail = if violations.is_empty() {
            "all primitives verified".to_string()
        } else {
            let preview = violations.into_iter().take(3).collect::<Vec<_>>().join("; ");
            format!("violations: {}", preview)
        };

        (passed, detail)
    }
}

use crate::math_utils::{logit, min_max_normalize, safe_prob, sigmoid};

/// Gating function for sparse signal logits before aggregation.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum Gating {
    /// No gating (pass-through).
    #[default]
    NoGating,
    /// MAP estimate under sparse prior (Theorem 6.5.3): max(0, logit).
    Relu,
    /// Bayes estimate under sparse prior (Theorem 6.7.4): logit * sigmoid(logit).
    Swish,
}

/// Apply gating function to a single logit value.
+fn apply_gating(logit_val: f64, gating: Gating) -> f64 { + match gating { + Gating::NoGating => logit_val, + Gating::Relu => logit_val.max(0.0), + Gating::Swish => logit_val * sigmoid(logit_val), + } +} + +/// Maps cosine similarity [-1, 1] to probability (0, 1). +pub fn cosine_to_probability(score: f64) -> f64 { + safe_prob((1.0 + score) / 2.0) +} + +/// Probabilistic complement: P(not A) = 1 - P(A). +pub fn prob_not(prob: f64) -> f64 { + safe_prob(1.0 - safe_prob(prob)) +} + +/// Probabilistic AND via product rule in log-space. +/// +/// P(A1 AND A2 AND ... AND An) = product(p_i), computed as exp(sum(ln(p_i))) +/// for numerical stability. +pub fn prob_and(probs: &[f64]) -> f64 { + let log_sum: f64 = probs.iter().map(|&p| safe_prob(p).ln()).sum(); + log_sum.exp() +} + +/// Probabilistic OR via complement rule in log-space. +/// +/// P(A1 OR A2 OR ... OR An) = 1 - product(1 - p_i), computed as +/// 1 - exp(sum(ln(1 - p_i))) for numerical stability. +pub fn prob_or(probs: &[f64]) -> f64 { + let log_complement_sum: f64 = probs + .iter() + .map(|&p| (1.0 - safe_prob(p)).ln()) + .sum(); + 1.0 - log_complement_sum.exp() +} + +/// Log-odds conjunction (paper Eq. 20/23, Theorem 8.3). 
+/// +/// Unweighted (weights=None): +/// sigmoid(mean(logit(p_i)) * n^alpha) +/// Default alpha = 0.5 +/// +/// Weighted (weights=Some): +/// sigmoid(n^alpha * sum(w_i * logit(p_i))) +/// Default alpha = 0.0 +/// Requires: all w_i >= 0, sum(w_i) = 1.0 +/// +/// Gating is applied to logit values before aggregation: +/// NoGating: pass-through +/// Relu: max(0, logit) -- MAP under sparse prior (Theorem 6.5.3) +/// Swish: logit * sigmoid(logit) -- Bayes under sparse prior (Theorem 6.7.4) +pub fn log_odds_conjunction( + probs: &[f64], + alpha: Option, + weights: Option<&[f64]>, + gating: Gating, +) -> f64 { + if probs.is_empty() { + return 0.5; + } + let n = probs.len() as f64; + + // Compute gated logits + let gated_logits: Vec = probs + .iter() + .map(|&p| apply_gating(logit(safe_prob(p)), gating)) + .collect(); + + match weights { + None => { + let effective_alpha = alpha.unwrap_or(0.5); + let l_bar: f64 = gated_logits.iter().sum::() / n; + sigmoid(l_bar * n.powf(effective_alpha)) + } + Some(w) => { + assert_eq!( + w.len(), + probs.len(), + "weights length must match probs length" + ); + assert!( + w.iter().all(|&wi| wi >= 0.0), + "all weights must be non-negative" + ); + assert!( + (w.iter().sum::() - 1.0).abs() < 1e-6, + "weights must sum to 1.0" + ); + + let effective_alpha = alpha.unwrap_or(0.0); + let weighted_logit_sum: f64 = gated_logits + .iter() + .zip(w.iter()) + .map(|(&l, &wi)| wi * l) + .sum(); + sigmoid(n.powf(effective_alpha) * weighted_logit_sum) + } + } +} + +/// Balanced log-odds fusion for hybrid sparse-dense retrieval. +/// +/// Converts both score vectors to logit-space, min-max normalizes each, +/// then linearly blends: weight * dense_norm + (1 - weight) * sparse_norm. 
+pub fn balanced_log_odds_fusion( + sparse_probs: &[f64], + dense_similarities: &[f64], + weight: f64, +) -> Vec { + let n = sparse_probs.len(); + let logit_sparse: Vec = sparse_probs + .iter() + .map(|&p| logit(safe_prob(p))) + .collect(); + let logit_dense: Vec = dense_similarities + .iter() + .map(|&s| logit(cosine_to_probability(s))) + .collect(); + + let sparse_norm = min_max_normalize(&logit_sparse); + let dense_norm = min_max_normalize(&logit_dense); + + (0..n) + .map(|i| weight * dense_norm[i] + (1.0 - weight) * sparse_norm[i]) + .collect() +} diff --git a/src/hybrid_scorer.rs b/src/hybrid_scorer.rs index 5168bd4..ac7eeb1 100644 --- a/src/hybrid_scorer.rs +++ b/src/hybrid_scorer.rs @@ -1,9 +1,10 @@ use std::rc::Rc; use crate::bayesian_scorer::BayesianBM25Scorer; -use crate::math_utils::{logit, safe_log, safe_prob, sigmoid, EPSILON}; -use crate::vector_scorer::VectorScorer; use crate::corpus::Document; +use crate::fusion; +use crate::math_utils::EPSILON; +use crate::vector_scorer::VectorScorer; pub struct HybridScorer { bayesian: Rc, @@ -17,36 +18,11 @@ impl HybridScorer { } pub fn probabilistic_and(&self, probs: &[f64]) -> f64 { - if probs.is_empty() { - return 0.0; - } - let n = probs.len(); - if n == 1 { - return safe_prob(probs[0]); - } - - // Stage 1: Geometric mean in log-space - let mut log_sum = 0.0; - for p in probs { - let p = safe_prob(*p); - log_sum += safe_log(p); - } - let geo_mean = (log_sum / n as f64).exp(); - - // Stage 2: Log-odds transformation with agreement bonus - let l_adjusted = logit(geo_mean) + self.alpha * (n as f64).ln(); - - // Stage 3: Return to probability space - sigmoid(l_adjusted) + fusion::log_odds_conjunction(probs, Some(self.alpha), None, fusion::Gating::NoGating) } pub fn probabilistic_or(&self, probs: &[f64]) -> f64 { - let mut log_complement_sum = 0.0; - for p in probs { - let p = safe_prob(*p); - log_complement_sum += safe_log(1.0 - p); - } - 1.0 - log_complement_sum.exp() + fusion::prob_or(probs) } pub fn score_and( 
diff --git a/src/learnable_weights.rs b/src/learnable_weights.rs new file mode 100644 index 0000000..81e0411 --- /dev/null +++ b/src/learnable_weights.rs @@ -0,0 +1,192 @@ +use crate::fusion::{log_odds_conjunction, Gating}; +use crate::math_utils::{logit, safe_prob, sigmoid, softmax}; + +/// Learnable per-signal reliability weights for log-odds conjunction (Remark 5.3.2). +/// +/// Learns weights that map from the Naive Bayes uniform initialization +/// (w_i = 1/n) to per-signal reliability weights via softmax parameterization. +/// +/// The gradient dL/dz_j = n^alpha * (p - y) * w_j * (x_j - x_bar_w) +/// is Hebbian: the product of pre-synaptic activity (signal deviation +/// from weighted mean) and post-synaptic error (prediction minus label). +pub struct LearnableLogOddsWeights { + n_signals: usize, + alpha: f64, + logits: Vec, + n_updates: usize, + grad_logits_ema: Vec, + weights_avg: Vec, +} + +impl LearnableLogOddsWeights { + pub fn new(n_signals: usize, alpha: f64) -> Self { + assert!(n_signals >= 1, "n_signals must be >= 1, got {}", n_signals); + let uniform = 1.0 / n_signals as f64; + Self { + n_signals, + alpha, + logits: vec![0.0; n_signals], + n_updates: 0, + grad_logits_ema: vec![0.0; n_signals], + weights_avg: vec![uniform; n_signals], + } + } + + pub fn n_signals(&self) -> usize { + self.n_signals + } + + pub fn alpha(&self) -> f64 { + self.alpha + } + + /// Current weights: softmax of internal logits. + pub fn weights(&self) -> Vec { + softmax(&self.logits) + } + + /// Polyak-averaged weights for stable inference. + pub fn averaged_weights(&self) -> Vec { + self.weights_avg.clone() + } + + /// Combine probability signals via weighted log-odds conjunction. + pub fn combine(&self, probs: &[f64], use_averaged: bool) -> f64 { + let w = if use_averaged { + self.weights_avg.clone() + } else { + self.weights() + }; + log_odds_conjunction(probs, Some(self.alpha), Some(&w), Gating::NoGating) + } + + /// Batch gradient descent on BCE loss to learn weights. 
+ pub fn fit( + &mut self, + probs: &[Vec], + labels: &[f64], + learning_rate: f64, + max_iterations: usize, + tolerance: f64, + ) { + let m = probs.len(); + let n = self.n_signals; + let scale = (n as f64).powf(self.alpha); + + // Precompute log-odds of input signals + let x: Vec> = probs + .iter() + .map(|row| { + assert_eq!(row.len(), n, "probs row length {} != n_signals {}", row.len(), n); + row.iter().map(|&p| logit(safe_prob(p))).collect() + }) + .collect(); + + for _ in 0..max_iterations { + let w = softmax(&self.logits); + + let mut grad_logits = vec![0.0; n]; + + for i in 0..m { + // Weighted mean log-odds + let x_bar_w: f64 = w.iter().zip(x[i].iter()).map(|(&wj, &xj)| wj * xj).sum(); + + // Predicted probability + let p = sigmoid(scale * x_bar_w); + let error = p - labels[i]; + + // Gradient for each logit z_j + for j in 0..n { + grad_logits[j] += scale * error * w[j] * (x[i][j] - x_bar_w); + } + } + + // Average over samples + let mut max_change = 0.0_f64; + for j in 0..n { + grad_logits[j] /= m as f64; + let delta = learning_rate * grad_logits[j]; + self.logits[j] -= delta; + max_change = max_change.max(delta.abs()); + } + + if max_change < tolerance { + break; + } + } + + // Reset online state + self.n_updates = 0; + self.grad_logits_ema = vec![0.0; n]; + self.weights_avg = softmax(&self.logits); + } + + /// Online SGD update from a single observation or mini-batch. 
+ pub fn update( + &mut self, + probs: &[Vec], + labels: &[f64], + learning_rate: f64, + momentum: f64, + decay_tau: f64, + max_grad_norm: f64, + avg_decay: f64, + ) { + let m = probs.len(); + let n = self.n_signals; + let scale = (n as f64).powf(self.alpha); + let w = softmax(&self.logits); + + let mut grad_logits = vec![0.0; n]; + + for i in 0..m { + assert_eq!(probs[i].len(), n); + let x: Vec = probs[i].iter().map(|&p| logit(safe_prob(p))).collect(); + let x_bar_w: f64 = w.iter().zip(x.iter()).map(|(&wj, &xj)| wj * xj).sum(); + let p = sigmoid(scale * x_bar_w); + let error = p - labels[i]; + + for j in 0..n { + grad_logits[j] += scale * error * w[j] * (x[j] - x_bar_w); + } + } + + // Average over mini-batch + for j in 0..n { + grad_logits[j] /= m as f64; + } + + // EMA smoothing + for j in 0..n { + self.grad_logits_ema[j] = + momentum * self.grad_logits_ema[j] + (1.0 - momentum) * grad_logits[j]; + } + + // Bias correction + self.n_updates += 1; + let correction = 1.0 - momentum.powi(self.n_updates as i32); + let mut corrected: Vec = self.grad_logits_ema.iter().map(|&g| g / correction).collect(); + + // L2 gradient clipping + let grad_norm: f64 = corrected.iter().map(|&g| g * g).sum::().sqrt(); + if grad_norm > max_grad_norm { + let clip_scale = max_grad_norm / grad_norm; + for g in corrected.iter_mut() { + *g *= clip_scale; + } + } + + // Learning rate decay + let effective_lr = learning_rate / (1.0 + self.n_updates as f64 / decay_tau); + + for j in 0..n { + self.logits[j] -= effective_lr * corrected[j]; + } + + // Polyak averaging of weights in the simplex + let raw_weights = softmax(&self.logits); + for j in 0..n { + self.weights_avg[j] = avg_decay * self.weights_avg[j] + (1.0 - avg_decay) * raw_weights[j]; + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 1d0d381..7e2c89c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,9 +5,15 @@ pub mod bm25_scorer; pub mod bayesian_scorer; pub mod vector_scorer; pub mod hybrid_scorer; +pub mod fusion; pub mod 
parameter_learner; pub mod experiments; pub mod defaults; +pub mod probability; +pub mod learnable_weights; +pub mod attention_weights; +pub mod metrics; +pub mod debug; #[cfg(feature = "python")] mod pybindings; @@ -17,13 +23,26 @@ pub use math_utils::{ cosine_similarity, dot_product, logit, + min_max_normalize, safe_log, safe_prob, sigmoid, + softmax, + softmax_rows, vector_magnitude, EPSILON, }; +pub use fusion::{ + balanced_log_odds_fusion, + cosine_to_probability, + log_odds_conjunction, + prob_and, + prob_not, + prob_or, + Gating, +}; + pub use tokenizer::Tokenizer; pub use corpus::{Corpus, Document}; pub use bm25_scorer::BM25Scorer; @@ -33,3 +52,23 @@ pub use hybrid_scorer::HybridScorer; pub use parameter_learner::{ParameterLearner, ParameterLearnerResult}; pub use experiments::{ExperimentRunner, Query}; pub use defaults::{build_default_corpus, build_default_queries}; +pub use probability::{BayesianProbabilityTransform, TrainingMode}; +pub use learnable_weights::LearnableLogOddsWeights; +pub use attention_weights::AttentionLogOddsWeights; +pub use metrics::{ + brier_score, + calibration_report, + expected_calibration_error, + reliability_diagram, + CalibrationReport, +}; +pub use debug::{ + BM25SignalTrace, + ComparisonResult, + DocumentTrace, + FusionDebugger, + FusionTrace, + NotTrace, + SignalTrace, + VectorSignalTrace, +}; diff --git a/src/main.rs b/src/main.rs index 74f3b5e..1674da0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -23,8 +23,9 @@ fn main() { println!(); let mut all_passed = true; - for (name, passed, details) in results { - let status = if passed { "PASS" } else { "FAIL" }; + let total = results.len(); + for (name, passed, details) in &results { + let status = if *passed { "PASS" } else { "FAIL" }; if !passed { all_passed = false; } @@ -38,7 +39,7 @@ fn main() { println!("{}", "=".repeat(72)); if all_passed { - println!("All 10 experiments PASSED."); + println!("All {} experiments PASSED.", total); } else { println!("Some experiments 
FAILED."); } diff --git a/src/math_utils.rs b/src/math_utils.rs index 3b836b3..9d2bf39 100644 --- a/src/math_utils.rs +++ b/src/math_utils.rs @@ -49,3 +49,46 @@ pub fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 { } dot_product(a, b) / (mag_a * mag_b) } + +/// Numerically stable softmax over a 1D slice. +/// +/// Shifts by max to prevent overflow, then normalizes. +pub fn softmax(z: &[f64]) -> Vec { + if z.is_empty() { + return Vec::new(); + } + let max_z = z.iter().copied().fold(f64::NEG_INFINITY, f64::max); + let exp_z: Vec = z.iter().map(|&v| (v - max_z).exp()).collect(); + let sum: f64 = exp_z.iter().sum(); + exp_z.iter().map(|&e| e / sum).collect() +} + +/// Row-wise softmax over a 2D array stored as a flat slice. +/// +/// Each row of `n_cols` elements gets an independent softmax. +pub fn softmax_rows(z: &[f64], n_cols: usize) -> Vec { + let n_rows = z.len() / n_cols; + let mut result = vec![0.0; z.len()]; + for r in 0..n_rows { + let start = r * n_cols; + let end = start + n_cols; + let row = &z[start..end]; + let sm = softmax(row); + result[start..end].copy_from_slice(&sm); + } + result +} + +/// Min-max normalize a slice to [0, 1]. Returns zeros if range is negligible. +pub fn min_max_normalize(values: &[f64]) -> Vec { + if values.is_empty() { + return Vec::new(); + } + let min_val = values.iter().copied().fold(f64::INFINITY, f64::min); + let max_val = values.iter().copied().fold(f64::NEG_INFINITY, f64::max); + let range = max_val - min_val; + if range < 1e-12 { + return vec![0.0; values.len()]; + } + values.iter().map(|&v| (v - min_val) / range).collect() +} diff --git a/src/metrics.rs b/src/metrics.rs new file mode 100644 index 0000000..2c1bc94 --- /dev/null +++ b/src/metrics.rs @@ -0,0 +1,150 @@ +/// Calibration metrics for evaluating probability quality. + +/// A single bin in a reliability diagram: (avg_predicted, avg_actual, count). +pub type ReliabilityBin = (f64, f64, usize); + +/// Expected Calibration Error (ECE). 
+/// +/// Measures how well predicted probabilities match actual relevance rates. +/// Lower is better. Perfect calibration = 0. +pub fn expected_calibration_error( + probabilities: &[f64], + labels: &[f64], + n_bins: usize, +) -> f64 { + let total = probabilities.len() as f64; + let mut ece = 0.0; + + for bin_idx in 0..n_bins { + let lo = bin_idx as f64 / n_bins as f64; + let hi = (bin_idx + 1) as f64 / n_bins as f64; + + let mut sum_prob = 0.0; + let mut sum_label = 0.0; + let mut count = 0usize; + + for (i, &p) in probabilities.iter().enumerate() { + let in_bin = if bin_idx == 0 { + p >= lo && p <= hi + } else { + p > lo && p <= hi + }; + if in_bin { + sum_prob += p; + sum_label += labels[i]; + count += 1; + } + } + + if count == 0 { + continue; + } + + let avg_prob = sum_prob / count as f64; + let avg_label = sum_label / count as f64; + ece += (count as f64 / total) * (avg_prob - avg_label).abs(); + } + + ece +} + +/// Brier score: mean squared error between probabilities and labels. +/// +/// Decomposes into calibration + discrimination. Lower is better. +pub fn brier_score(probabilities: &[f64], labels: &[f64]) -> f64 { + let n = probabilities.len() as f64; + probabilities + .iter() + .zip(labels.iter()) + .map(|(&p, &y)| (p - y) * (p - y)) + .sum::() + / n +} + +/// Compute reliability diagram data: (avg_predicted, avg_actual, count) per bin. +/// +/// Perfect calibration means avg_predicted == avg_actual for every bin. 
+pub fn reliability_diagram( + probabilities: &[f64], + labels: &[f64], + n_bins: usize, +) -> Vec { + let mut bins = Vec::new(); + + for bin_idx in 0..n_bins { + let lo = bin_idx as f64 / n_bins as f64; + let hi = (bin_idx + 1) as f64 / n_bins as f64; + + let mut sum_prob = 0.0; + let mut sum_label = 0.0; + let mut count = 0usize; + + for (i, &p) in probabilities.iter().enumerate() { + let in_bin = if bin_idx == 0 { + p >= lo && p <= hi + } else { + p > lo && p <= hi + }; + if in_bin { + sum_prob += p; + sum_label += labels[i]; + count += 1; + } + } + + if count > 0 { + bins.push((sum_prob / count as f64, sum_label / count as f64, count)); + } + } + + bins +} + +/// One-call calibration diagnostic report. +pub struct CalibrationReport { + pub ece: f64, + pub brier: f64, + pub reliability: Vec, + pub n_samples: usize, + pub n_bins: usize, +} + +impl CalibrationReport { + /// Formatted text summary of calibration metrics. + pub fn summary(&self) -> String { + let mut lines = vec![ + "Calibration Report".to_string(), + "==================".to_string(), + format!(" Samples : {}", self.n_samples), + format!(" Bins : {}", self.n_bins), + format!(" ECE : {:.6}", self.ece), + format!(" Brier : {:.6}", self.brier), + String::new(), + " Reliability Diagram".to_string(), + " -------------------".to_string(), + format!(" {:>10} {:>10} {:>6}", "Predicted", "Actual", "Count"), + ]; + for &(avg_pred, avg_actual, count) in &self.reliability { + lines.push(format!( + " {:>10.4} {:>10.4} {:>6}", + avg_pred, avg_actual, count + )); + } + lines.join("\n") + } +} + +/// Compute a full calibration diagnostic report in one call. 
+pub fn calibration_report( + probabilities: &[f64], + labels: &[f64], + n_bins: usize, +) -> CalibrationReport { + CalibrationReport { + ece: expected_calibration_error(probabilities, labels, n_bins), + brier: brier_score(probabilities, labels), + reliability: reliability_diagram(probabilities, labels, n_bins), + n_samples: probabilities.len(), + n_bins, + } +} diff --git a/src/probability.rs b/src/probability.rs new file mode 100644 index 0000000..8cade15 --- /dev/null +++ b/src/probability.rs @@ -0,0 +1,336 @@ +use crate::math_utils::{clamp, safe_prob, sigmoid, EPSILON}; + +const ALPHA_MIN: f64 = 0.01; + +/// Training mode for parameter learning (C1/C2/C3 conditions). +#[derive(Clone, Copy, Debug, Default, PartialEq)] +pub enum TrainingMode { + /// C1: Train on sigmoid likelihood pred = sigmoid(alpha*(s-beta)). + #[default] + Balanced, + /// C2: Train on full Bayesian posterior with composite prior. + PriorAware, + /// C3: Same training as balanced, but at inference prior=0.5. + PriorFree, +} + +/// Transforms raw BM25 scores into calibrated probabilities. +/// +/// Implements sigmoid likelihood + composite prior + Bayesian posterior +/// with optional base_rate correction (two-step Bayes update). +/// +/// Supports batch fitting (gradient descent) and online learning +/// (SGD with EMA, Polyak averaging, gradient clipping). 
+pub struct BayesianProbabilityTransform { + pub alpha: f64, + pub beta: f64, + pub base_rate: Option, + #[allow(dead_code)] + logit_base_rate: Option, + training_mode: TrainingMode, + n_updates: usize, + grad_alpha_ema: f64, + grad_beta_ema: f64, + alpha_avg: f64, + beta_avg: f64, +} + +impl BayesianProbabilityTransform { + pub fn new(alpha: f64, beta: f64, base_rate: Option) -> Self { + if let Some(br) = base_rate { + assert!( + br > 0.0 && br < 1.0, + "base_rate must be in (0, 1), got {}", + br + ); + } + let logit_br = base_rate.map(|br| { + let br = clamp(br, EPSILON, 1.0 - EPSILON); + (br / (1.0 - br)).ln() + }); + Self { + alpha, + beta, + base_rate, + logit_base_rate: logit_br, + training_mode: TrainingMode::Balanced, + n_updates: 0, + grad_alpha_ema: 0.0, + grad_beta_ema: 0.0, + alpha_avg: alpha, + beta_avg: beta, + } + } + + /// EMA-averaged alpha for stable inference after online updates. + pub fn averaged_alpha(&self) -> f64 { + self.alpha_avg + } + + /// EMA-averaged beta for stable inference after online updates. + pub fn averaged_beta(&self) -> f64 { + self.beta_avg + } + + /// Current training mode. + pub fn training_mode(&self) -> TrainingMode { + self.training_mode + } + + /// Sigmoid likelihood: sigma(alpha * (score - beta)). + pub fn likelihood(&self, score: f64) -> f64 { + sigmoid(self.alpha * (score - self.beta)) + } + + /// Term-frequency prior: 0.2 + 0.7 * min(1, tf / 10). + pub fn tf_prior(tf: f64) -> f64 { + 0.2 + 0.7 * (tf / 10.0).min(1.0) + } + + /// Document-length normalization prior (Eq. 26). + /// + /// P_norm = 0.3 + 0.6 * (1 - min(1, |doc_len_ratio - 0.5| * 2)) + pub fn norm_prior(doc_len_ratio: f64) -> f64 { + 0.3 + 0.6 * (1.0 - ((doc_len_ratio - 0.5).abs() * 2.0).min(1.0)) + } + + /// Composite prior: clamp(0.7 * P_tf + 0.3 * P_norm, 0.1, 0.9). 
+ pub fn composite_prior(tf: f64, doc_len_ratio: f64) -> f64 { + let p_tf = Self::tf_prior(tf); + let p_norm = Self::norm_prior(doc_len_ratio); + clamp(0.7 * p_tf + 0.3 * p_norm, 0.1, 0.9) + } + + /// Bayesian posterior via two-step Bayes update. + /// + /// Without base_rate: P = L*p / (L*p + (1-L)*(1-p)) + /// With base_rate: second Bayes update using base_rate as corpus-level prior. + pub fn posterior(likelihood_val: f64, prior: f64, base_rate: Option) -> f64 { + let l = safe_prob(likelihood_val); + let p = safe_prob(prior); + let numerator = l * p; + let denominator = numerator + (1.0 - l) * (1.0 - p); + let mut result = safe_prob(numerator / denominator); + + if let Some(br) = base_rate { + let num_br = result * br; + let den_br = num_br + (1.0 - result) * (1.0 - br); + result = safe_prob(num_br / den_br); + } + + result + } + + /// Full pipeline: BM25 score -> calibrated probability. + pub fn score_to_probability( + &self, + score: f64, + tf: f64, + doc_len_ratio: f64, + ) -> f64 { + let l_val = self.likelihood(score); + + let prior = if self.training_mode == TrainingMode::PriorFree { + 0.5 + } else { + Self::composite_prior(tf, doc_len_ratio) + }; + + Self::posterior(l_val, prior, self.base_rate) + } + + /// WAND upper bound for safe document pruning (Theorem 6.1.2). + pub fn wand_upper_bound(&self, bm25_upper_bound: f64, p_max: f64) -> f64 { + let l_max = self.likelihood(bm25_upper_bound); + Self::posterior(l_max, p_max, self.base_rate) + } + + /// Batch gradient descent to learn alpha and beta (Algorithm 8.3.1). 
+ pub fn fit( + &mut self, + scores: &[f64], + labels: &[f64], + learning_rate: f64, + max_iterations: usize, + tolerance: f64, + mode: TrainingMode, + tfs: Option<&[f64]>, + doc_len_ratios: Option<&[f64]>, + ) { + if mode == TrainingMode::PriorAware { + assert!( + tfs.is_some() && doc_len_ratios.is_some(), + "tfs and doc_len_ratios are required when mode is PriorAware" + ); + } + + let priors: Option> = if mode == TrainingMode::PriorAware { + let tfs = tfs.unwrap(); + let dlrs = doc_len_ratios.unwrap(); + Some( + tfs.iter() + .zip(dlrs.iter()) + .map(|(&tf, &dlr)| Self::composite_prior(tf, dlr)) + .collect(), + ) + } else { + Option::None + }; + + let mut alpha = self.alpha; + let mut beta = self.beta; + let n = scores.len() as f64; + + for _ in 0..max_iterations { + let (grad_alpha, grad_beta) = if mode == TrainingMode::PriorAware { + let priors_ref = priors.as_ref().unwrap(); + compute_prior_aware_gradients(scores, labels, priors_ref, alpha, beta, n) + } else { + compute_balanced_gradients(scores, labels, alpha, beta, n) + }; + + let new_alpha = alpha - learning_rate * grad_alpha; + let new_beta = beta - learning_rate * grad_beta; + + if (new_alpha - alpha).abs() < tolerance && (new_beta - beta).abs() < tolerance { + alpha = new_alpha; + beta = new_beta; + break; + } + + alpha = new_alpha; + beta = new_beta; + } + + self.alpha = alpha; + self.beta = beta; + self.training_mode = mode; + self.n_updates = 0; + self.grad_alpha_ema = 0.0; + self.grad_beta_ema = 0.0; + self.alpha_avg = alpha; + self.beta_avg = beta; + } + + /// Online SGD update from a single observation or mini-batch. 
+ pub fn update( + &mut self, + scores: &[f64], + labels: &[f64], + learning_rate: f64, + momentum: f64, + decay_tau: f64, + max_grad_norm: f64, + avg_decay: f64, + mode: Option, + tfs: Option<&[f64]>, + doc_len_ratios: Option<&[f64]>, + ) { + let effective_mode = mode.unwrap_or(self.training_mode); + if effective_mode == TrainingMode::PriorAware { + assert!( + tfs.is_some() && doc_len_ratios.is_some(), + "tfs and doc_len_ratios are required when mode is PriorAware" + ); + } + + let n = scores.len() as f64; + + let (grad_alpha, grad_beta) = if effective_mode == TrainingMode::PriorAware { + let tfs = tfs.unwrap(); + let dlrs = doc_len_ratios.unwrap(); + let priors: Vec = tfs + .iter() + .zip(dlrs.iter()) + .map(|(&tf, &dlr)| Self::composite_prior(tf, dlr)) + .collect(); + compute_prior_aware_gradients(scores, labels, &priors, self.alpha, self.beta, n) + } else { + compute_balanced_gradients(scores, labels, self.alpha, self.beta, n) + }; + + if mode.is_some() { + self.training_mode = effective_mode; + } + + // EMA smoothing + self.grad_alpha_ema = momentum * self.grad_alpha_ema + (1.0 - momentum) * grad_alpha; + self.grad_beta_ema = momentum * self.grad_beta_ema + (1.0 - momentum) * grad_beta; + + // Bias correction + self.n_updates += 1; + let correction = 1.0 - momentum.powi(self.n_updates as i32); + let mut corrected_alpha = self.grad_alpha_ema / correction; + let mut corrected_beta = self.grad_beta_ema / correction; + + // Gradient clipping + let grad_norm = (corrected_alpha * corrected_alpha + corrected_beta * corrected_beta).sqrt(); + if grad_norm > max_grad_norm { + let scale = max_grad_norm / grad_norm; + corrected_alpha *= scale; + corrected_beta *= scale; + } + + // Learning rate decay + let effective_lr = learning_rate / (1.0 + self.n_updates as f64 / decay_tau); + + self.alpha -= effective_lr * corrected_alpha; + self.beta -= effective_lr * corrected_beta; + + // Alpha must stay positive + if self.alpha < ALPHA_MIN { + self.alpha = ALPHA_MIN; + } + + // 
Polyak parameter averaging + self.alpha_avg = avg_decay * self.alpha_avg + (1.0 - avg_decay) * self.alpha; + self.beta_avg = avg_decay * self.beta_avg + (1.0 - avg_decay) * self.beta; + } +} + +/// Compute gradients for balanced/prior_free training mode. +fn compute_balanced_gradients( + scores: &[f64], + labels: &[f64], + alpha: f64, + beta: f64, + n: f64, +) -> (f64, f64) { + let mut grad_alpha = 0.0; + let mut grad_beta = 0.0; + for (&s, &y) in scores.iter().zip(labels.iter()) { + let l = safe_prob(sigmoid(alpha * (s - beta))); + let error = l - y; + grad_alpha += error * (s - beta); + grad_beta += error * (-alpha); + } + (grad_alpha / n, grad_beta / n) +} + +/// Compute gradients for prior_aware training mode. +fn compute_prior_aware_gradients( + scores: &[f64], + labels: &[f64], + priors: &[f64], + alpha: f64, + beta: f64, + n: f64, +) -> (f64, f64) { + let mut grad_alpha = 0.0; + let mut grad_beta = 0.0; + for (i, (&s, &y)) in scores.iter().zip(labels.iter()).enumerate() { + let l = safe_prob(sigmoid(alpha * (s - beta))); + let p = priors[i]; + let denom = l * p + (1.0 - l) * (1.0 - p); + let predicted = safe_prob(l * p / denom); + + let dp_dl = p * (1.0 - p) / (denom * denom); + let dl_dalpha = l * (1.0 - l) * (s - beta); + let dl_dbeta = -l * (1.0 - l) * alpha; + + let error = predicted - y; + grad_alpha += error * dp_dl * dl_dalpha; + grad_beta += error * dp_dl * dl_dbeta; + } + (grad_alpha / n, grad_beta / n) +} diff --git a/src/pybindings.rs b/src/pybindings.rs index 7e97cbd..54c59d8 100644 --- a/src/pybindings.rs +++ b/src/pybindings.rs @@ -10,10 +10,40 @@ use crate::bm25_scorer::BM25Scorer; use crate::corpus::{Corpus as CoreCorpus, Document}; use crate::defaults::{build_default_corpus, build_default_queries}; use crate::experiments::{ExperimentRunner, Query}; +use crate::fusion; use crate::hybrid_scorer::HybridScorer; use crate::parameter_learner::{ParameterLearner, ParameterLearnerResult}; use crate::tokenizer::Tokenizer; use 
crate::vector_scorer::VectorScorer; +use crate::probability::BayesianProbabilityTransform; +use crate::learnable_weights::LearnableLogOddsWeights; +use crate::attention_weights::AttentionLogOddsWeights; +use crate::metrics; +use crate::debug::{FusionDebugger, BM25SignalTrace, VectorSignalTrace, FusionTrace, DocumentTrace, ComparisonResult, NotTrace}; + +fn parse_gating(gating: Option<&str>) -> PyResult { + match gating { + None | Some("none") => Ok(fusion::Gating::NoGating), + Some("relu") => Ok(fusion::Gating::Relu), + Some("swish") => Ok(fusion::Gating::Swish), + Some(other) => Err(PyValueError::new_err(format!( + "gating must be 'none', 'relu', or 'swish', got '{}'", + other + ))), + } +} + +fn parse_training_mode(mode: Option<&str>) -> PyResult { + match mode { + None | Some("balanced") => Ok(crate::probability::TrainingMode::Balanced), + Some("prior_aware") => Ok(crate::probability::TrainingMode::PriorAware), + Some("prior_free") => Ok(crate::probability::TrainingMode::PriorFree), + Some(other) => Err(PyValueError::new_err(format!( + "mode must be 'balanced', 'prior_aware', or 'prior_free', got '{}'", + other + ))), + } +} #[pyclass(name = "Tokenizer")] pub struct PyTokenizer { @@ -120,6 +150,7 @@ impl PyCorpus { #[pymethods] impl PyCorpus { #[new] + #[pyo3(signature = (_tokenizer=None))] fn new(_tokenizer: Option<&PyTokenizer>) -> Self { let core = CoreCorpus::new(Tokenizer::new()); Self { @@ -203,6 +234,7 @@ pub struct PyBM25Scorer { #[pymethods] impl PyBM25Scorer { #[new] + #[pyo3(signature = (corpus, k1=None, b=None))] fn new(corpus: &PyCorpus, k1: Option, b: Option) -> PyResult { let corpus = corpus.shared_corpus()?; Ok(Self { @@ -243,12 +275,14 @@ pub struct PyBayesianBM25Scorer { #[pymethods] impl PyBayesianBM25Scorer { #[new] - fn new(bm25: &PyBM25Scorer, alpha: Option, beta: Option) -> Self { + #[pyo3(signature = (bm25, alpha=None, beta=None, base_rate=None))] + fn new(bm25: &PyBM25Scorer, alpha: Option, beta: Option, base_rate: Option) -> Self { Self 
{ inner: Rc::new(BayesianBM25Scorer::new( Rc::clone(&bm25.inner), alpha.unwrap_or(1.0), beta.unwrap_or(0.5), + base_rate, )), } } @@ -280,6 +314,11 @@ impl PyBayesianBM25Scorer { fn score(&self, query_terms: Vec, doc: &PyDocument) -> f64 { self.inner.score(&query_terms, &doc.inner) } + + #[getter] + fn base_rate(&self) -> Option { + self.inner.base_rate() + } } #[pyclass(unsendable, name = "VectorScorer")] @@ -357,6 +396,7 @@ pub struct PyParameterLearner { #[pymethods] impl PyParameterLearner { #[new] + #[pyo3(signature = (learning_rate=None, max_iterations=None, tolerance=None))] fn new(learning_rate: Option, max_iterations: Option, tolerance: Option) -> Self { Self { inner: ParameterLearner::new( @@ -466,6 +506,7 @@ pub struct PyExperimentRunner { #[pymethods] impl PyExperimentRunner { #[new] + #[pyo3(signature = (corpus, queries, k1=None, b=None))] fn new(corpus: &PyCorpus, queries: Vec>, k1: Option, b: Option) -> PyResult { let corpus = corpus.shared_corpus()?; let mut query_list = Vec::with_capacity(queries.len()); @@ -522,6 +563,853 @@ fn build_default_queries_py(py: Python) -> PyResult>> { Ok(out) } +#[pyfunction(name = "prob_not")] +fn prob_not_py(prob: f64) -> f64 { + fusion::prob_not(prob) +} + +#[pyfunction(name = "prob_and")] +fn prob_and_py(probs: Vec) -> f64 { + fusion::prob_and(&probs) +} + +#[pyfunction(name = "prob_or")] +fn prob_or_py(probs: Vec) -> f64 { + fusion::prob_or(&probs) +} + +#[pyfunction(name = "cosine_to_probability")] +fn cosine_to_probability_py(score: f64) -> f64 { + fusion::cosine_to_probability(score) +} + +#[pyfunction(name = "log_odds_conjunction")] +#[pyo3(signature = (probs, alpha=None, weights=None, gating=None))] +fn log_odds_conjunction_py( + probs: Vec, + alpha: Option, + weights: Option>, + gating: Option<&str>, +) -> PyResult { + if let Some(ref w) = weights { + if w.len() != probs.len() { + return Err(PyValueError::new_err( + "weights length must match probs length", + )); + } + if !w.iter().all(|&wi| wi >= 0.0) { + 
return Err(PyValueError::new_err("all weights must be non-negative")); + } + if (w.iter().sum::() - 1.0).abs() >= 1e-6 { + return Err(PyValueError::new_err("weights must sum to 1.0")); + } + } + let g = parse_gating(gating)?; + Ok(fusion::log_odds_conjunction( + &probs, + alpha, + weights.as_deref(), + g, + )) +} + +#[pyfunction(name = "balanced_log_odds_fusion")] +#[pyo3(signature = (sparse_probs, dense_similarities, weight=None))] +fn balanced_log_odds_fusion_py( + sparse_probs: Vec, + dense_similarities: Vec, + weight: Option, +) -> PyResult> { + if sparse_probs.len() != dense_similarities.len() { + return Err(PyValueError::new_err( + "sparse_probs and dense_similarities must have the same length", + )); + } + Ok(fusion::balanced_log_odds_fusion( + &sparse_probs, + &dense_similarities, + weight.unwrap_or(0.5), + )) +} + +// --------------------------------------------------------------------------- +// BayesianProbabilityTransform +// --------------------------------------------------------------------------- + +#[pyclass(unsendable, name = "BayesianProbabilityTransform")] +pub struct PyBayesianProbabilityTransform { + inner: RefCell, +} + +#[pymethods] +impl PyBayesianProbabilityTransform { + #[new] + #[pyo3(signature = (alpha=None, beta=None, base_rate=None))] + fn new(alpha: Option, beta: Option, base_rate: Option) -> PyResult { + if let Some(br) = base_rate { + if br <= 0.0 || br >= 1.0 { + return Err(PyValueError::new_err(format!( + "base_rate must be in (0, 1), got {}", + br + ))); + } + } + Ok(Self { + inner: RefCell::new(BayesianProbabilityTransform::new( + alpha.unwrap_or(1.0), + beta.unwrap_or(0.0), + base_rate, + )), + }) + } + + #[getter] + fn alpha(&self) -> f64 { + self.inner.borrow().alpha + } + + #[getter] + fn beta(&self) -> f64 { + self.inner.borrow().beta + } + + #[getter] + fn base_rate(&self) -> Option { + self.inner.borrow().base_rate + } + + #[getter] + fn averaged_alpha(&self) -> f64 { + self.inner.borrow().averaged_alpha() + } + + 
#[getter] + fn averaged_beta(&self) -> f64 { + self.inner.borrow().averaged_beta() + } + + fn likelihood(&self, score: f64) -> f64 { + self.inner.borrow().likelihood(score) + } + + #[staticmethod] + fn tf_prior(tf: f64) -> f64 { + BayesianProbabilityTransform::tf_prior(tf) + } + + #[staticmethod] + fn norm_prior(doc_len_ratio: f64) -> f64 { + BayesianProbabilityTransform::norm_prior(doc_len_ratio) + } + + #[staticmethod] + fn composite_prior(tf: f64, doc_len_ratio: f64) -> f64 { + BayesianProbabilityTransform::composite_prior(tf, doc_len_ratio) + } + + #[staticmethod] + #[pyo3(signature = (likelihood_val, prior, base_rate=None))] + fn posterior(likelihood_val: f64, prior: f64, base_rate: Option) -> f64 { + BayesianProbabilityTransform::posterior(likelihood_val, prior, base_rate) + } + + fn score_to_probability(&self, score: f64, tf: f64, doc_len_ratio: f64) -> f64 { + self.inner.borrow().score_to_probability(score, tf, doc_len_ratio) + } + + #[pyo3(signature = (bm25_upper_bound, p_max=None))] + fn wand_upper_bound(&self, bm25_upper_bound: f64, p_max: Option) -> f64 { + self.inner.borrow().wand_upper_bound(bm25_upper_bound, p_max.unwrap_or(0.9)) + } + + #[pyo3(signature = (scores, labels, learning_rate=None, max_iterations=None, tolerance=None, mode=None, tfs=None, doc_len_ratios=None))] + fn fit( + &self, + scores: Vec, + labels: Vec, + learning_rate: Option, + max_iterations: Option, + tolerance: Option, + mode: Option<&str>, + tfs: Option>, + doc_len_ratios: Option>, + ) -> PyResult<()> { + let m = parse_training_mode(mode)?; + self.inner.borrow_mut().fit( + &scores, + &labels, + learning_rate.unwrap_or(0.01), + max_iterations.unwrap_or(1000), + tolerance.unwrap_or(1e-6), + m, + tfs.as_deref(), + doc_len_ratios.as_deref(), + ); + Ok(()) + } + + #[pyo3(signature = (score, label, learning_rate=None, momentum=None, decay_tau=None, max_grad_norm=None, avg_decay=None, mode=None, tf=None, doc_len_ratio=None))] + #[allow(clippy::too_many_arguments)] + fn update( + 
&self, + score: Vec, + label: Vec, + learning_rate: Option, + momentum: Option, + decay_tau: Option, + max_grad_norm: Option, + avg_decay: Option, + mode: Option<&str>, + tf: Option>, + doc_len_ratio: Option>, + ) -> PyResult<()> { + let m = if mode.is_some() { + Some(parse_training_mode(mode)?) + } else { + None + }; + self.inner.borrow_mut().update( + &score, + &label, + learning_rate.unwrap_or(0.01), + momentum.unwrap_or(0.9), + decay_tau.unwrap_or(1000.0), + max_grad_norm.unwrap_or(1.0), + avg_decay.unwrap_or(0.995), + m, + tf.as_deref(), + doc_len_ratio.as_deref(), + ); + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// LearnableLogOddsWeights +// --------------------------------------------------------------------------- + +#[pyclass(unsendable, name = "LearnableLogOddsWeights")] +pub struct PyLearnableLogOddsWeights { + inner: RefCell, +} + +#[pymethods] +impl PyLearnableLogOddsWeights { + #[new] + #[pyo3(signature = (n_signals, alpha=None))] + fn new(n_signals: usize, alpha: Option) -> PyResult { + if n_signals < 1 { + return Err(PyValueError::new_err(format!( + "n_signals must be >= 1, got {}", + n_signals + ))); + } + Ok(Self { + inner: RefCell::new(LearnableLogOddsWeights::new( + n_signals, + alpha.unwrap_or(0.0), + )), + }) + } + + #[getter] + fn n_signals(&self) -> usize { + self.inner.borrow().n_signals() + } + + #[getter] + fn alpha(&self) -> f64 { + self.inner.borrow().alpha() + } + + #[getter] + fn weights(&self) -> Vec { + self.inner.borrow().weights() + } + + #[getter] + fn averaged_weights(&self) -> Vec { + self.inner.borrow().averaged_weights() + } + + #[pyo3(signature = (probs, use_averaged=None))] + fn combine(&self, probs: Vec, use_averaged: Option) -> f64 { + self.inner.borrow().combine(&probs, use_averaged.unwrap_or(false)) + } + + #[pyo3(signature = (probs, labels, learning_rate=None, max_iterations=None, tolerance=None))] + fn fit( + &self, + probs: Vec>, + labels: Vec, + learning_rate: 
Option, + max_iterations: Option, + tolerance: Option, + ) -> PyResult<()> { + self.inner.borrow_mut().fit( + &probs, + &labels, + learning_rate.unwrap_or(0.01), + max_iterations.unwrap_or(1000), + tolerance.unwrap_or(1e-6), + ); + Ok(()) + } + + #[pyo3(signature = (probs, label, learning_rate=None, momentum=None, decay_tau=None, max_grad_norm=None, avg_decay=None))] + fn update( + &self, + probs: Vec>, + label: Vec, + learning_rate: Option, + momentum: Option, + decay_tau: Option, + max_grad_norm: Option, + avg_decay: Option, + ) -> PyResult<()> { + self.inner.borrow_mut().update( + &probs, + &label, + learning_rate.unwrap_or(0.01), + momentum.unwrap_or(0.9), + decay_tau.unwrap_or(1000.0), + max_grad_norm.unwrap_or(1.0), + avg_decay.unwrap_or(0.995), + ); + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// AttentionLogOddsWeights +// --------------------------------------------------------------------------- + +#[pyclass(unsendable, name = "AttentionLogOddsWeights")] +pub struct PyAttentionLogOddsWeights { + inner: RefCell, +} + +#[pymethods] +impl PyAttentionLogOddsWeights { + #[new] + #[pyo3(signature = (n_signals, n_query_features, alpha=None, normalize=None))] + fn new( + n_signals: usize, + n_query_features: usize, + alpha: Option, + normalize: Option, + ) -> PyResult { + if n_signals < 1 { + return Err(PyValueError::new_err(format!( + "n_signals must be >= 1, got {}", + n_signals + ))); + } + if n_query_features < 1 { + return Err(PyValueError::new_err(format!( + "n_query_features must be >= 1, got {}", + n_query_features + ))); + } + Ok(Self { + inner: RefCell::new(AttentionLogOddsWeights::new( + n_signals, + n_query_features, + alpha.unwrap_or(0.5), + normalize.unwrap_or(false), + )), + }) + } + + #[getter] + fn n_signals(&self) -> usize { + self.inner.borrow().n_signals() + } + + #[getter] + fn n_query_features(&self) -> usize { + self.inner.borrow().n_query_features() + } + + #[getter] + fn alpha(&self) 
-> f64 { + self.inner.borrow().alpha() + } + + #[getter] + fn normalize(&self) -> bool { + self.inner.borrow().normalize() + } + + #[getter] + fn weights_matrix(&self) -> Vec { + self.inner.borrow().weights_matrix() + } + + #[pyo3(signature = (probs, m, query_features, m_q, use_averaged=None))] + fn combine( + &self, + probs: Vec, + m: usize, + query_features: Vec, + m_q: usize, + use_averaged: Option, + ) -> Vec { + self.inner.borrow().combine( + &probs, + m, + &query_features, + m_q, + use_averaged.unwrap_or(false), + ) + } + + #[pyo3(signature = (probs, labels, query_features, m, query_ids=None, learning_rate=None, max_iterations=None, tolerance=None))] + #[allow(clippy::too_many_arguments)] + fn fit( + &self, + probs: Vec, + labels: Vec, + query_features: Vec, + m: usize, + query_ids: Option>, + learning_rate: Option, + max_iterations: Option, + tolerance: Option, + ) -> PyResult<()> { + self.inner.borrow_mut().fit( + &probs, + &labels, + &query_features, + m, + query_ids.as_deref(), + learning_rate.unwrap_or(0.01), + max_iterations.unwrap_or(1000), + tolerance.unwrap_or(1e-6), + ); + Ok(()) + } + + #[pyo3(signature = (probs, labels, query_features, m, learning_rate=None, momentum=None, decay_tau=None, max_grad_norm=None, avg_decay=None))] + #[allow(clippy::too_many_arguments)] + fn update( + &self, + probs: Vec, + labels: Vec, + query_features: Vec, + m: usize, + learning_rate: Option, + momentum: Option, + decay_tau: Option, + max_grad_norm: Option, + avg_decay: Option, + ) -> PyResult<()> { + self.inner.borrow_mut().update( + &probs, + &labels, + &query_features, + m, + learning_rate.unwrap_or(0.01), + momentum.unwrap_or(0.9), + decay_tau.unwrap_or(1000.0), + max_grad_norm.unwrap_or(1.0), + avg_decay.unwrap_or(0.995), + ); + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// Calibration Metrics +// --------------------------------------------------------------------------- + +#[pyclass(name = 
"CalibrationReport")] +pub struct PyCalibrationReport { + #[pyo3(get)] + ece: f64, + #[pyo3(get)] + brier: f64, + #[pyo3(get)] + reliability: Vec<(f64, f64, usize)>, + #[pyo3(get)] + n_samples: usize, + #[pyo3(get)] + n_bins: usize, +} + +#[pymethods] +impl PyCalibrationReport { + fn summary(&self) -> String { + let report = metrics::CalibrationReport { + ece: self.ece, + brier: self.brier, + reliability: self.reliability.clone(), + n_samples: self.n_samples, + n_bins: self.n_bins, + }; + report.summary() + } +} + +#[pyfunction(name = "expected_calibration_error")] +#[pyo3(signature = (probabilities, labels, n_bins=None))] +fn expected_calibration_error_py( + probabilities: Vec, + labels: Vec, + n_bins: Option, +) -> f64 { + metrics::expected_calibration_error(&probabilities, &labels, n_bins.unwrap_or(10)) +} + +#[pyfunction(name = "brier_score")] +fn brier_score_py(probabilities: Vec, labels: Vec) -> f64 { + metrics::brier_score(&probabilities, &labels) +} + +#[pyfunction(name = "reliability_diagram")] +#[pyo3(signature = (probabilities, labels, n_bins=None))] +fn reliability_diagram_py( + probabilities: Vec, + labels: Vec, + n_bins: Option, +) -> Vec<(f64, f64, usize)> { + metrics::reliability_diagram(&probabilities, &labels, n_bins.unwrap_or(10)) +} + +#[pyfunction(name = "calibration_report")] +#[pyo3(signature = (probabilities, labels, n_bins=None))] +fn calibration_report_py( + probabilities: Vec, + labels: Vec, + n_bins: Option, +) -> PyCalibrationReport { + let r = metrics::calibration_report(&probabilities, &labels, n_bins.unwrap_or(10)); + PyCalibrationReport { + ece: r.ece, + brier: r.brier, + reliability: r.reliability, + n_samples: r.n_samples, + n_bins: r.n_bins, + } +} + +// --------------------------------------------------------------------------- +// FusionDebugger +// --------------------------------------------------------------------------- + +#[pyclass(name = "BM25SignalTrace")] +pub struct PyBM25SignalTrace { + #[pyo3(get)] + raw_score: f64, 
+ #[pyo3(get)] + tf: f64, + #[pyo3(get)] + doc_len_ratio: f64, + #[pyo3(get)] + likelihood: f64, + #[pyo3(get)] + tf_prior: f64, + #[pyo3(get)] + norm_prior: f64, + #[pyo3(get)] + composite_prior: f64, + #[pyo3(get)] + logit_likelihood: f64, + #[pyo3(get)] + logit_prior: f64, + #[pyo3(get)] + logit_base_rate: Option, + #[pyo3(get)] + posterior: f64, + #[pyo3(get)] + alpha: f64, + #[pyo3(get)] + beta: f64, + #[pyo3(get)] + base_rate: Option, +} + +impl PyBM25SignalTrace { + fn from_core(t: &BM25SignalTrace) -> Self { + Self { + raw_score: t.raw_score, + tf: t.tf, + doc_len_ratio: t.doc_len_ratio, + likelihood: t.likelihood, + tf_prior: t.tf_prior, + norm_prior: t.norm_prior, + composite_prior: t.composite_prior, + logit_likelihood: t.logit_likelihood, + logit_prior: t.logit_prior, + logit_base_rate: t.logit_base_rate, + posterior: t.posterior, + alpha: t.alpha, + beta: t.beta, + base_rate: t.base_rate, + } + } +} + +#[pyclass(name = "VectorSignalTrace")] +pub struct PyVectorSignalTrace { + #[pyo3(get)] + cosine_score: f64, + #[pyo3(get)] + probability: f64, + #[pyo3(get)] + logit_probability: f64, +} + +impl PyVectorSignalTrace { + fn from_core(t: &VectorSignalTrace) -> Self { + Self { + cosine_score: t.cosine_score, + probability: t.probability, + logit_probability: t.logit_probability, + } + } +} + +#[pyclass(name = "NotTrace")] +pub struct PyNotTrace { + #[pyo3(get)] + input_probability: f64, + #[pyo3(get)] + input_name: String, + #[pyo3(get)] + complement: f64, + #[pyo3(get)] + logit_input: f64, + #[pyo3(get)] + logit_complement: f64, +} + +impl PyNotTrace { + fn from_core(t: &NotTrace) -> Self { + Self { + input_probability: t.input_probability, + input_name: t.input_name.clone(), + complement: t.complement, + logit_input: t.logit_input, + logit_complement: t.logit_complement, + } + } +} + +#[pyclass(name = "FusionTrace")] +pub struct PyFusionTrace { + #[pyo3(get)] + signal_probabilities: Vec, + #[pyo3(get)] + signal_names: Vec, + #[pyo3(get)] + method: String, 
+ #[pyo3(get)] + logits: Option>, + #[pyo3(get)] + mean_logit: Option, + #[pyo3(get)] + alpha: Option, + #[pyo3(get)] + n_alpha_scale: Option, + #[pyo3(get)] + scaled_logit: Option, + #[pyo3(get)] + weights: Option>, + #[pyo3(get)] + log_probs: Option>, + #[pyo3(get)] + log_prob_sum: Option, + #[pyo3(get)] + complements: Option>, + #[pyo3(get)] + log_complements: Option>, + #[pyo3(get)] + log_complement_sum: Option, + #[pyo3(get)] + fused_probability: f64, +} + +impl PyFusionTrace { + fn from_core(t: &FusionTrace) -> Self { + Self { + signal_probabilities: t.signal_probabilities.clone(), + signal_names: t.signal_names.clone(), + method: t.method.clone(), + logits: t.logits.clone(), + mean_logit: t.mean_logit, + alpha: t.alpha, + n_alpha_scale: t.n_alpha_scale, + scaled_logit: t.scaled_logit, + weights: t.weights.clone(), + log_probs: t.log_probs.clone(), + log_prob_sum: t.log_prob_sum, + complements: t.complements.clone(), + log_complements: t.log_complements.clone(), + log_complement_sum: t.log_complement_sum, + fused_probability: t.fused_probability, + } + } +} + +#[pyclass(name = "DocumentTrace")] +pub struct PyDocumentTrace { + inner: DocumentTrace, +} + +#[pymethods] +impl PyDocumentTrace { + #[getter] + fn doc_id(&self) -> Option { + self.inner.doc_id.clone() + } + + #[getter] + fn final_probability(&self) -> f64 { + self.inner.final_probability + } + + #[getter] + fn fusion(&self) -> PyFusionTrace { + PyFusionTrace::from_core(&self.inner.fusion) + } +} + +#[pyclass(name = "ComparisonResult")] +pub struct PyComparisonResult { + inner: ComparisonResult, +} + +#[pymethods] +impl PyComparisonResult { + #[getter] + fn signal_deltas(&self) -> Vec<(String, f64)> { + self.inner.signal_deltas.clone() + } + + #[getter] + fn dominant_signal(&self) -> String { + self.inner.dominant_signal.clone() + } + + #[getter] + fn crossover_stage(&self) -> Option { + self.inner.crossover_stage.clone() + } +} + +#[pyclass(name = "FusionDebugger")] +pub struct PyFusionDebugger { + 
inner: FusionDebugger, +} + +#[pymethods] +impl PyFusionDebugger { + #[new] + #[pyo3(signature = (alpha=None, beta=None, base_rate=None))] + fn new(alpha: Option, beta: Option, base_rate: Option) -> PyResult { + if let Some(br) = base_rate { + if br <= 0.0 || br >= 1.0 { + return Err(PyValueError::new_err(format!( + "base_rate must be in (0, 1), got {}", + br + ))); + } + } + let transform = BayesianProbabilityTransform::new( + alpha.unwrap_or(1.0), + beta.unwrap_or(0.0), + base_rate, + ); + Ok(Self { + inner: FusionDebugger::new(transform), + }) + } + + fn trace_bm25(&self, score: f64, tf: f64, doc_len_ratio: f64) -> PyBM25SignalTrace { + PyBM25SignalTrace::from_core(&self.inner.trace_bm25(score, tf, doc_len_ratio)) + } + + fn trace_vector(&self, cosine_score: f64) -> PyVectorSignalTrace { + PyVectorSignalTrace::from_core(&self.inner.trace_vector(cosine_score)) + } + + #[pyo3(signature = (probability, name=None))] + fn trace_not(&self, probability: f64, name: Option<&str>) -> PyNotTrace { + PyNotTrace::from_core(&self.inner.trace_not(probability, name.unwrap_or("signal"))) + } + + #[pyo3(signature = (probabilities, names=None, method=None, alpha=None, weights=None))] + fn trace_fusion( + &self, + probabilities: Vec, + names: Option>, + method: Option<&str>, + alpha: Option, + weights: Option>, + ) -> PyFusionTrace { + let trace = self.inner.trace_fusion( + &probabilities, + names.as_deref(), + method.unwrap_or("log_odds"), + alpha, + weights.as_deref(), + ); + PyFusionTrace::from_core(&trace) + } + + #[pyo3(signature = (bm25_score=None, tf=None, doc_len_ratio=None, cosine_score=None, method=None, alpha=None, weights=None, doc_id=None))] + #[allow(clippy::too_many_arguments)] + fn trace_document( + &self, + bm25_score: Option, + tf: Option, + doc_len_ratio: Option, + cosine_score: Option, + method: Option<&str>, + alpha: Option, + weights: Option>, + doc_id: Option<&str>, + ) -> PyResult { + if bm25_score.is_none() && cosine_score.is_none() { + return 
Err(PyValueError::new_err( + "At least one of bm25_score or cosine_score must be provided", + )); + } + let trace = self.inner.trace_document( + bm25_score, + tf, + doc_len_ratio, + cosine_score, + method.unwrap_or("log_odds"), + alpha, + weights.as_deref(), + doc_id, + ); + Ok(PyDocumentTrace { inner: trace }) + } + + fn compare(&self, trace_a: &PyDocumentTrace, trace_b: &PyDocumentTrace) -> PyComparisonResult { + let result = self.inner.compare(&trace_a.inner, &trace_b.inner); + PyComparisonResult { inner: result } + } + + #[pyo3(signature = (trace, verbose=None))] + fn format_trace(&self, trace: &PyDocumentTrace, verbose: Option) -> String { + self.inner.format_trace(&trace.inner, verbose.unwrap_or(true)) + } + + fn format_summary(&self, trace: &PyDocumentTrace) -> String { + self.inner.format_summary(&trace.inner) + } + + fn format_comparison(&self, comparison: &PyComparisonResult) -> String { + self.inner.format_comparison(&comparison.inner) + } +} + +// --------------------------------------------------------------------------- + #[pyfunction(name = "run_experiments")] fn run_experiments_py() -> Vec { let corpus = Rc::new(build_default_corpus()); @@ -536,6 +1424,7 @@ fn run_experiments_py() -> Vec { #[pymodule] fn bb25(m: &Bound<'_, PyModule>) -> PyResult<()> { + // Core types m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -549,11 +1438,42 @@ fn bb25(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + // Probability transform + m.add_class::()?; + + // Learnable weights + m.add_class::()?; + m.add_class::()?; + + // Calibration metrics + m.add_class::()?; + + // Debug/trace types + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + // Functions m.add_function(wrap_pyfunction!(build_default_corpus_py, m)?)?; m.add_function(wrap_pyfunction!(build_default_queries_py, m)?)?; m.add_function(wrap_pyfunction!(run_experiments_py, m)?)?; + 
m.add_function(wrap_pyfunction!(prob_not_py, m)?)?; + m.add_function(wrap_pyfunction!(prob_and_py, m)?)?; + m.add_function(wrap_pyfunction!(prob_or_py, m)?)?; + m.add_function(wrap_pyfunction!(cosine_to_probability_py, m)?)?; + m.add_function(wrap_pyfunction!(log_odds_conjunction_py, m)?)?; + m.add_function(wrap_pyfunction!(balanced_log_odds_fusion_py, m)?)?; + m.add_function(wrap_pyfunction!(expected_calibration_error_py, m)?)?; + m.add_function(wrap_pyfunction!(brier_score_py, m)?)?; + m.add_function(wrap_pyfunction!(reliability_diagram_py, m)?)?; + m.add_function(wrap_pyfunction!(calibration_report_py, m)?)?; m.add("__all__", vec![ + // Core types "Tokenizer", "Document", "Corpus", @@ -566,9 +1486,35 @@ fn bb25(m: &Bound<'_, PyModule>) -> PyResult<()> { "Query", "ExperimentResult", "ExperimentRunner", + // Probability transform + "BayesianProbabilityTransform", + // Learnable weights + "LearnableLogOddsWeights", + "AttentionLogOddsWeights", + // Calibration + "CalibrationReport", + // Debug + "BM25SignalTrace", + "VectorSignalTrace", + "NotTrace", + "FusionTrace", + "DocumentTrace", + "ComparisonResult", + "FusionDebugger", + // Functions "build_default_corpus", "build_default_queries", "run_experiments", + "prob_not", + "prob_and", + "prob_or", + "cosine_to_probability", + "log_odds_conjunction", + "balanced_log_odds_fusion", + "expected_calibration_error", + "brier_score", + "reliability_diagram", + "calibration_report", ])?; Ok(()) diff --git a/tests/test_debug.py b/tests/test_debug.py new file mode 100644 index 0000000..952f9ac --- /dev/null +++ b/tests/test_debug.py @@ -0,0 +1,206 @@ +import unittest + +import bb25 as bb + + +class TestFusionDebuggerCreation(unittest.TestCase): + def test_default(self): + d = bb.FusionDebugger() + self.assertIsNotNone(d) + + def test_with_params(self): + d = bb.FusionDebugger(alpha=2.0, beta=1.0, base_rate=0.1) + self.assertIsNotNone(d) + + +class TestTraceBM25(unittest.TestCase): + def test_basic_trace(self): + d = 
bb.FusionDebugger(alpha=1.0, beta=0.5) + trace = d.trace_bm25(2.0, 5.0, 0.5) + self.assertAlmostEqual(trace.raw_score, 2.0) + self.assertAlmostEqual(trace.tf, 5.0) + self.assertAlmostEqual(trace.doc_len_ratio, 0.5) + self.assertGreater(trace.likelihood, 0.5) + self.assertGreater(trace.posterior, 0.0) + self.assertLess(trace.posterior, 1.0) + self.assertAlmostEqual(trace.alpha, 1.0) + self.assertAlmostEqual(trace.beta, 0.5) + + def test_prior_values(self): + d = bb.FusionDebugger(alpha=1.0, beta=0.5) + trace = d.trace_bm25(2.0, 5.0, 0.5) + # tf_prior for tf=5: 0.2 + 0.7 * 0.5 = 0.55 + self.assertAlmostEqual(trace.tf_prior, 0.55) + # norm_prior at ratio=0.5 peaks at 0.9 + self.assertAlmostEqual(trace.norm_prior, 0.9) + + def test_base_rate_trace(self): + d = bb.FusionDebugger(alpha=1.0, beta=0.5, base_rate=0.1) + trace = d.trace_bm25(2.0, 5.0, 0.5) + self.assertAlmostEqual(trace.base_rate, 0.1) + self.assertIsNotNone(trace.logit_base_rate) + + +class TestTraceVector(unittest.TestCase): + def test_basic(self): + d = bb.FusionDebugger() + trace = d.trace_vector(0.8) + self.assertAlmostEqual(trace.cosine_score, 0.8) + self.assertAlmostEqual(trace.probability, 0.9) + self.assertGreater(trace.logit_probability, 0.0) + + def test_negative_cosine(self): + d = bb.FusionDebugger() + trace = d.trace_vector(-0.5) + self.assertLess(trace.probability, 0.5) + + +class TestTraceNot(unittest.TestCase): + def test_complement(self): + d = bb.FusionDebugger() + trace = d.trace_not(0.8, "BM25") + self.assertAlmostEqual(trace.input_probability, 0.8) + self.assertEqual(trace.input_name, "BM25") + self.assertAlmostEqual(trace.complement, 0.2, places=5) + + def test_logit_sign_flip(self): + d = bb.FusionDebugger() + trace = d.trace_not(0.8, "signal") + self.assertAlmostEqual(trace.logit_input, -trace.logit_complement, places=5) + + +class TestTraceFusion(unittest.TestCase): + def test_log_odds(self): + d = bb.FusionDebugger() + trace = d.trace_fusion([0.8, 0.7]) + 
self.assertEqual(trace.method, "log_odds") + self.assertIsNotNone(trace.logits) + self.assertIsNotNone(trace.mean_logit) + self.assertIsNotNone(trace.n_alpha_scale) + self.assertGreater(trace.fused_probability, 0.0) + + def test_prob_and(self): + d = bb.FusionDebugger() + trace = d.trace_fusion([0.8, 0.7], method="prob_and") + self.assertEqual(trace.method, "prob_and") + self.assertIsNotNone(trace.log_probs) + self.assertAlmostEqual(trace.fused_probability, 0.8 * 0.7, places=5) + + def test_prob_or(self): + d = bb.FusionDebugger() + trace = d.trace_fusion([0.8, 0.7], method="prob_or") + self.assertEqual(trace.method, "prob_or") + self.assertIsNotNone(trace.complements) + expected = 1.0 - (1.0 - 0.8) * (1.0 - 0.7) + self.assertAlmostEqual(trace.fused_probability, expected, places=5) + + def test_prob_not(self): + d = bb.FusionDebugger() + trace = d.trace_fusion([0.8, 0.7], method="prob_not") + self.assertEqual(trace.method, "prob_not") + expected = (1.0 - 0.8) * (1.0 - 0.7) + self.assertAlmostEqual(trace.fused_probability, expected, places=5) + + def test_weighted_log_odds(self): + d = bb.FusionDebugger() + trace = d.trace_fusion([0.8, 0.7], weights=[0.6, 0.4]) + self.assertEqual(trace.method, "log_odds") + self.assertIsNotNone(trace.weights) + self.assertEqual(len(trace.weights), 2) + + def test_custom_names(self): + d = bb.FusionDebugger() + trace = d.trace_fusion([0.8, 0.7], names=["sparse", "dense"]) + self.assertEqual(trace.signal_names, ["sparse", "dense"]) + + +class TestTraceDocument(unittest.TestCase): + def test_bm25_only(self): + d = bb.FusionDebugger(alpha=1.0, beta=0.5) + trace = d.trace_document(bm25_score=2.0, tf=5.0, doc_len_ratio=0.5, doc_id="d01") + self.assertEqual(trace.doc_id, "d01") + self.assertGreater(trace.final_probability, 0.0) + + def test_vector_only(self): + d = bb.FusionDebugger() + trace = d.trace_document(cosine_score=0.8, doc_id="d02") + self.assertEqual(trace.doc_id, "d02") + self.assertGreater(trace.final_probability, 0.0) + + def 
test_hybrid(self): + d = bb.FusionDebugger(alpha=1.0, beta=0.5) + trace = d.trace_document( + bm25_score=2.0, tf=5.0, doc_len_ratio=0.5, + cosine_score=0.8, doc_id="d01" + ) + self.assertGreater(trace.final_probability, 0.0) + + def test_no_signals_raises(self): + d = bb.FusionDebugger() + with self.assertRaises(ValueError): + d.trace_document() + + +class TestCompare(unittest.TestCase): + def test_basic_comparison(self): + d = bb.FusionDebugger(alpha=1.0, beta=0.5) + ta = d.trace_document(bm25_score=3.0, tf=5.0, doc_len_ratio=0.5, doc_id="d01") + tb = d.trace_document(bm25_score=1.0, tf=2.0, doc_len_ratio=0.8, doc_id="d02") + cmp = d.compare(ta, tb) + self.assertIsNotNone(cmp.dominant_signal) + self.assertGreater(len(cmp.signal_deltas), 0) + + def test_crossover_detection(self): + d = bb.FusionDebugger(alpha=1.0, beta=0.5) + # d01: strong BM25, weak vector + ta = d.trace_document( + bm25_score=5.0, tf=8.0, doc_len_ratio=0.5, + cosine_score=0.1, doc_id="d01" + ) + # d02: weak BM25, strong vector + tb = d.trace_document( + bm25_score=0.5, tf=1.0, doc_len_ratio=0.8, + cosine_score=0.95, doc_id="d02" + ) + cmp = d.compare(ta, tb) + # Should detect crossover between signals + self.assertIsNotNone(cmp.crossover_stage) + + +class TestFormatting(unittest.TestCase): + def test_format_trace(self): + d = bb.FusionDebugger(alpha=1.0, beta=0.5) + trace = d.trace_document(bm25_score=2.0, tf=5.0, doc_len_ratio=0.5, doc_id="d01") + text = d.format_trace(trace) + self.assertIn("Document: d01", text) + self.assertIn("BM25", text) + self.assertIn("final=", text) + + def test_format_trace_non_verbose(self): + d = bb.FusionDebugger(alpha=1.0, beta=0.5) + trace = d.trace_document(bm25_score=2.0, tf=5.0, doc_len_ratio=0.5, doc_id="d01") + text = d.format_trace(trace, verbose=False) + self.assertIn("Document: d01", text) + # Non-verbose should not have logit details + self.assertNotIn("logit(posterior)", text) + + def test_format_summary(self): + d = bb.FusionDebugger(alpha=1.0, 
beta=0.5) + trace = d.trace_document(bm25_score=2.0, tf=5.0, doc_len_ratio=0.5, doc_id="d01") + summary = d.format_summary(trace) + self.assertIn("d01", summary) + self.assertIn("Fused=", summary) + + def test_format_comparison(self): + d = bb.FusionDebugger(alpha=1.0, beta=0.5) + ta = d.trace_document(bm25_score=3.0, tf=5.0, doc_len_ratio=0.5, doc_id="d01") + tb = d.trace_document(bm25_score=1.0, tf=2.0, doc_len_ratio=0.8, doc_id="d02") + cmp = d.compare(ta, tb) + text = d.format_comparison(cmp) + self.assertIn("Comparison:", text) + self.assertIn("Dominant signal:", text) + self.assertIn("Rank order:", text) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_fusion.py b/tests/test_fusion.py new file mode 100644 index 0000000..c8a613e --- /dev/null +++ b/tests/test_fusion.py @@ -0,0 +1,121 @@ +import unittest + +import bb25 as bb + + +class TestGating(unittest.TestCase): + """Test gating parameter in log_odds_conjunction.""" + + def test_no_gating_default(self): + result = bb.log_odds_conjunction([0.8, 0.8]) + self.assertGreater(result, 0.5) + + def test_relu_gating(self): + # Relu zeros out negative logits (prob < 0.5) + result_none = bb.log_odds_conjunction([0.8, 0.3], gating="none") + result_relu = bb.log_odds_conjunction([0.8, 0.3], gating="relu") + # With relu, the negative logit from 0.3 becomes 0, so result should be higher + self.assertGreater(result_relu, result_none) + + def test_swish_gating(self): + result_none = bb.log_odds_conjunction([0.8, 0.3], gating="none") + result_swish = bb.log_odds_conjunction([0.8, 0.3], gating="swish") + # Swish is softer than relu, should still increase relative to none + self.assertGreater(result_swish, result_none) + + def test_gating_with_weights(self): + probs = [0.8, 0.3] + weights = [0.6, 0.4] + result = bb.log_odds_conjunction(probs, weights=weights, gating="relu") + self.assertGreater(result, 0.0) + self.assertLess(result, 1.0) + + def test_invalid_gating(self): + with 
self.assertRaises(ValueError): + bb.log_odds_conjunction([0.8, 0.8], gating="invalid") + + +class TestLearnableLogOddsWeights(unittest.TestCase): + """Test LearnableLogOddsWeights.""" + + def test_creation(self): + w = bb.LearnableLogOddsWeights(3) + self.assertEqual(w.n_signals, 3) + self.assertAlmostEqual(w.alpha, 0.0) + # Uniform initialization + weights = w.weights + self.assertEqual(len(weights), 3) + for wi in weights: + self.assertAlmostEqual(wi, 1.0 / 3, places=6) + + def test_combine(self): + w = bb.LearnableLogOddsWeights(2) + result = w.combine([0.8, 0.7]) + self.assertGreater(result, 0.0) + self.assertLess(result, 1.0) + + def test_fit(self): + w = bb.LearnableLogOddsWeights(2) + probs = [[0.9, 0.1], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8]] + labels = [1.0, 0.0, 1.0, 0.0] + w.fit(probs, labels) + # First signal should have higher weight after training + weights = w.weights + self.assertGreater(weights[0], weights[1]) + + def test_update(self): + w = bb.LearnableLogOddsWeights(2) + for _ in range(20): + w.update([[0.9, 0.1]], [1.0]) + w.update([[0.1, 0.9]], [0.0]) + # Averaged weights should be available + avg = w.averaged_weights + self.assertEqual(len(avg), 2) + + +class TestAttentionLogOddsWeights(unittest.TestCase): + """Test AttentionLogOddsWeights.""" + + def test_creation(self): + a = bb.AttentionLogOddsWeights(2, 3) + self.assertEqual(a.n_signals, 2) + self.assertEqual(a.n_query_features, 3) + self.assertAlmostEqual(a.alpha, 0.5) + self.assertFalse(a.normalize) + + def test_combine(self): + a = bb.AttentionLogOddsWeights(2, 3) + probs = [0.8, 0.7] + qf = [1.0, 0.0, 0.5] + result = a.combine(probs, 1, qf, 1) + self.assertEqual(len(result), 1) + self.assertGreater(result[0], 0.0) + self.assertLess(result[0], 1.0) + + def test_combine_batched(self): + a = bb.AttentionLogOddsWeights(2, 3) + # 3 candidates, 2 signals each + probs = [0.8, 0.7, 0.6, 0.5, 0.9, 0.1] + qf = [1.0, 0.0, 0.5] # single query + result = a.combine(probs, 3, qf, 1) + 
self.assertEqual(len(result), 3) + + def test_fit(self): + a = bb.AttentionLogOddsWeights(2, 2) + # 4 samples, 2 signals each + probs = [0.9, 0.1, 0.1, 0.9, 0.8, 0.2, 0.2, 0.8] + labels = [1.0, 0.0, 1.0, 0.0] + qf = [1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0] + a.fit(probs, labels, qf, 4) + + def test_normalize(self): + a = bb.AttentionLogOddsWeights(2, 3, normalize=True) + self.assertTrue(a.normalize) + probs = [0.8, 0.7, 0.6, 0.5] + qf = [1.0, 0.0, 0.5] + result = a.combine(probs, 2, qf, 1) + self.assertEqual(len(result), 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..988fb5d --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,78 @@ +import unittest + +import bb25 as bb + + +class TestBrierScore(unittest.TestCase): + def test_perfect_predictions(self): + probs = [1.0, 0.0, 1.0, 0.0] + labels = [1.0, 0.0, 1.0, 0.0] + self.assertAlmostEqual(bb.brier_score(probs, labels), 0.0) + + def test_worst_predictions(self): + probs = [0.0, 1.0] + labels = [1.0, 0.0] + self.assertAlmostEqual(bb.brier_score(probs, labels), 1.0) + + def test_constant_prediction(self): + probs = [0.5, 0.5, 0.5, 0.5] + labels = [1.0, 0.0, 1.0, 0.0] + self.assertAlmostEqual(bb.brier_score(probs, labels), 0.25) + + +class TestECE(unittest.TestCase): + def test_perfect_calibration(self): + probs = [0.0, 1.0] + labels = [0.0, 1.0] + ece = bb.expected_calibration_error(probs, labels) + self.assertAlmostEqual(ece, 0.0, places=5) + + def test_ece_non_negative(self): + probs = [0.1, 0.4, 0.6, 0.9] + labels = [0.0, 0.0, 1.0, 1.0] + ece = bb.expected_calibration_error(probs, labels) + self.assertGreaterEqual(ece, 0.0) + self.assertLessEqual(ece, 1.0) + + +class TestReliabilityDiagram(unittest.TestCase): + def test_basic(self): + probs = [0.1, 0.2, 0.3, 0.7, 0.8, 0.9] + labels = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0] + bins = bb.reliability_diagram(probs, labels, 5) + self.assertGreater(len(bins), 0) + for avg_pred, 
avg_actual, count in bins: + self.assertGreaterEqual(count, 1) + self.assertGreaterEqual(avg_pred, 0.0) + self.assertLessEqual(avg_pred, 1.0) + + def test_custom_bins(self): + probs = [0.1, 0.9] + labels = [0.0, 1.0] + bins = bb.reliability_diagram(probs, labels, 2) + self.assertEqual(len(bins), 2) + + +class TestCalibrationReport(unittest.TestCase): + def test_report_fields(self): + probs = [0.1, 0.2, 0.3, 0.7, 0.8, 0.9] + labels = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0] + report = bb.calibration_report(probs, labels) + self.assertEqual(report.n_samples, 6) + self.assertEqual(report.n_bins, 10) + self.assertGreaterEqual(report.ece, 0.0) + self.assertGreaterEqual(report.brier, 0.0) + + def test_summary_format(self): + probs = [0.1, 0.2, 0.3, 0.7, 0.8, 0.9] + labels = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0] + report = bb.calibration_report(probs, labels) + summary = report.summary() + self.assertIn("Calibration Report", summary) + self.assertIn("ECE", summary) + self.assertIn("Brier", summary) + self.assertIn("Reliability Diagram", summary) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_probability.py b/tests/test_probability.py new file mode 100644 index 0000000..d6b2773 --- /dev/null +++ b/tests/test_probability.py @@ -0,0 +1,165 @@ +import unittest + +import bb25 as bb + + +class TestBayesianProbabilityTransform(unittest.TestCase): + """Test BayesianProbabilityTransform.""" + + def test_basic_creation(self): + t = bb.BayesianProbabilityTransform() + self.assertAlmostEqual(t.alpha, 1.0) + self.assertAlmostEqual(t.beta, 0.0) + self.assertIsNone(t.base_rate) + + def test_custom_params(self): + t = bb.BayesianProbabilityTransform(alpha=2.0, beta=1.5, base_rate=0.1) + self.assertAlmostEqual(t.alpha, 2.0) + self.assertAlmostEqual(t.beta, 1.5) + self.assertAlmostEqual(t.base_rate, 0.1) + + def test_likelihood(self): + t = bb.BayesianProbabilityTransform(alpha=1.0, beta=0.0) + # At score=0 with beta=0: sigmoid(0) = 0.5 + self.assertAlmostEqual(t.likelihood(0.0), 
 0.5, places=6)
+        # Positive score -> probability > 0.5
+        self.assertGreater(t.likelihood(1.0), 0.5)
+        # Negative score -> probability < 0.5
+        self.assertLess(t.likelihood(-1.0), 0.5)
+
+    def test_tf_prior(self):
+        self.assertAlmostEqual(bb.BayesianProbabilityTransform.tf_prior(0.0), 0.2)
+        self.assertAlmostEqual(bb.BayesianProbabilityTransform.tf_prior(10.0), 0.9)
+        # Saturates at tf=10
+        self.assertAlmostEqual(
+            bb.BayesianProbabilityTransform.tf_prior(20.0), 0.9
+        )
+
+    def test_norm_prior(self):
+        # Peaks at ratio=0.5
+        self.assertAlmostEqual(bb.BayesianProbabilityTransform.norm_prior(0.5), 0.9)
+        # Falls at extremes
+        p_at_0 = bb.BayesianProbabilityTransform.norm_prior(0.0)
+        self.assertAlmostEqual(p_at_0, 0.3)
+
+    def test_composite_prior(self):
+        prior = bb.BayesianProbabilityTransform.composite_prior(5.0, 0.5)
+        self.assertGreaterEqual(prior, 0.1)
+        self.assertLessEqual(prior, 0.9)
+
+    def test_posterior(self):
+        # High likelihood + high prior -> high posterior
+        p = bb.BayesianProbabilityTransform.posterior(0.9, 0.8)
+        self.assertGreater(p, 0.9)
+
+        # With base_rate
+        p_br = bb.BayesianProbabilityTransform.posterior(0.9, 0.8, base_rate=0.1)
+        self.assertLess(p_br, p)
+
+    def test_score_to_probability(self):
+        t = bb.BayesianProbabilityTransform(alpha=1.0, beta=0.0)
+        p = t.score_to_probability(1.0, 5.0, 0.5)
+        self.assertGreater(p, 0.0)
+        self.assertLess(p, 1.0)
+
+    def test_invalid_base_rate(self):
+        with self.assertRaises(ValueError):
+            bb.BayesianProbabilityTransform(base_rate=0.0)
+        with self.assertRaises(ValueError):
+            bb.BayesianProbabilityTransform(base_rate=1.0)
+
+
+class TestWAND(unittest.TestCase):
+    """Test WAND upper bound."""
+
+    def test_wand_upper_bound(self):
+        t = bb.BayesianProbabilityTransform(alpha=1.0, beta=0.0)
+        ub = t.wand_upper_bound(5.0)
+        self.assertGreater(ub, 0.0)
+        self.assertLessEqual(ub, 1.0)
+
+    def test_wand_monotonic(self):
+        t = bb.BayesianProbabilityTransform(alpha=1.0, beta=0.0)
+        ub_low = t.wand_upper_bound(1.0)
+
        ub_high = t.wand_upper_bound(5.0)
+        self.assertGreater(ub_high, ub_low)
+
+    def test_wand_custom_p_max(self):
+        t = bb.BayesianProbabilityTransform(alpha=1.0, beta=0.0)
+        ub_default = t.wand_upper_bound(3.0)
+        ub_low = t.wand_upper_bound(3.0, p_max=0.5)
+        self.assertGreater(ub_default, ub_low)
+
+
+class TestFitBalanced(unittest.TestCase):
+    """Test batch fitting with balanced mode."""
+
+    def test_fit_balanced(self):
+        t = bb.BayesianProbabilityTransform(alpha=1.0, beta=0.0)
+        scores = [0.5, 1.0, 2.0, 3.0, 0.1, 0.2]
+        labels = [0.0, 0.0, 1.0, 1.0, 0.0, 0.0]
+        t.fit(scores, labels, mode="balanced")
+        self.assertNotAlmostEqual(t.alpha, 1.0, places=2)
+
+    def test_fit_convergence(self):
+        t = bb.BayesianProbabilityTransform(alpha=1.0, beta=0.0)
+        scores = [0.1, 0.5, 1.5, 3.0, 4.0, 5.0]
+        labels = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
+        t.fit(scores, labels, learning_rate=0.1, max_iterations=2000)
+        # High scores should get high probability
+        self.assertGreater(t.likelihood(5.0), 0.8)
+        # Low scores should get low probability
+        self.assertLess(t.likelihood(0.1), 0.3)
+
+
+class TestPriorAware(unittest.TestCase):
+    """Test prior-aware training mode."""
+
+    def test_fit_prior_aware(self):
+        t = bb.BayesianProbabilityTransform(alpha=1.0, beta=0.0)
+        scores = [0.5, 1.0, 2.0, 3.0]
+        labels = [0.0, 0.0, 1.0, 1.0]
+        tfs = [1.0, 2.0, 5.0, 8.0]
+        dlrs = [0.5, 0.7, 0.4, 0.6]
+        t.fit(scores, labels, mode="prior_aware", tfs=tfs, doc_len_ratios=dlrs)
+
+
+class TestOnlineUpdate(unittest.TestCase):
+    """Test online SGD updates."""
+
+    def test_update_online(self):
+        t = bb.BayesianProbabilityTransform(alpha=1.0, beta=0.0)
+        for _ in range(10):
+            t.update([2.0], [1.0])
+            t.update([0.1], [0.0])
+        self.assertIsNotNone(t.averaged_alpha)
+        self.assertIsNotNone(t.averaged_beta)
+
+    def test_update_moves_params(self):
+        t = bb.BayesianProbabilityTransform(alpha=1.0, beta=0.0)
+        alpha_before = t.alpha
+        for _ in range(50):
+            t.update([5.0], [1.0])
+            t.update([0.0], [0.0])
+        # Params should have
 moved
+        self.assertNotAlmostEqual(t.alpha, alpha_before, places=2)
+
+    def test_averaged_params_smooth(self):
+        t = bb.BayesianProbabilityTransform(alpha=1.0, beta=0.0)
+        raw_alphas = []
+        avg_alphas = []
+        for _ in range(100):
+            t.update([3.0], [1.0])
+            t.update([0.1], [0.0])
+            raw_alphas.append(t.alpha)
+            avg_alphas.append(t.averaged_alpha)
+        # Averaged alpha should have less variance than raw alpha
+        raw_diffs = [abs(raw_alphas[i+1] - raw_alphas[i]) for i in range(len(raw_alphas)-1)]
+        avg_diffs = [abs(avg_alphas[i+1] - avg_alphas[i]) for i in range(len(avg_alphas)-1)]
+        raw_var = sum(d * d for d in raw_diffs) / len(raw_diffs)
+        avg_var = sum(d * d for d in avg_diffs) / len(avg_diffs)
+        self.assertGreater(raw_var, avg_var)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_smoke.py b/tests/test_smoke.py
index f93d643..6f65f04 100644
--- a/tests/test_smoke.py
+++ b/tests/test_smoke.py
@@ -6,7 +6,7 @@ class SmokeTests(unittest.TestCase):
 
     def test_run_experiments(self):
         results = bb.run_experiments()
-        self.assertEqual(len(results), 10)
+        self.assertEqual(len(results), 13)
         self.assertTrue(all(r.passed for r in results))
 
     def test_default_builders(self):