Port advanced fusion, calibration, and debugging from Python to Rust #2

Open
jaepil wants to merge 3 commits into instructkr:main from jaepil:main

jaepil commented Mar 1, 2026

Summary

Full port of bayesian-bm25 Python features (v0.4.0-v0.8.0) to Rust with Python bindings.

Commit 1: Port 5 high-priority features

  • fusion.rs: cosine_to_probability, prob_not, prob_and, prob_or, log_odds_conjunction (paper Eq. 20/23, Theorem 8.3), balanced_log_odds_fusion
  • Fix log-odds conjunction formula: replace incorrect logit(geo_mean) + alpha*ln(n) with correct mean(logit(p_i)) * n^alpha
  • Add base_rate prior to BayesianBM25Scorer (Remark 4.4.5)
  • Refactor BayesianBM25Scorer::score() to use fusion::prob_or
  • Add experiments 11-13 (base rate, log-odds properties, fusion primitives)
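
To make the formula fix concrete, here is a minimal Python sketch of the fusion primitives and of both conjunction formulas. It is not the code in fusion.rs: the clamp constant `EPS`, the default `alpha=0.5`, the final sigmoid that maps the fused log-odds back to a probability, and the name `old_log_odds_conjunction` are assumptions made for illustration. `prob_and` and `prob_or` are written under the usual independence assumption.

```python
import math
from typing import Sequence

EPS = 1e-10  # assumed clamp that keeps logit() finite at 0 and 1


def logit(p: float) -> float:
    p = min(max(p, EPS), 1.0 - EPS)
    return math.log(p / (1.0 - p))


def sigmoid(x: float) -> float:
    # Branch on the sign so math.exp never overflows.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def prob_not(p: float) -> float:
    return 1.0 - p


def prob_and(probs: Sequence[float]) -> float:
    # Joint probability of independent events.
    return math.prod(probs)


def prob_or(probs: Sequence[float]) -> float:
    # Complement of "none of the independent events happened".
    return 1.0 - math.prod(1.0 - p for p in probs)


def log_odds_conjunction(probs: Sequence[float], alpha: float = 0.5) -> float:
    # Corrected form from this PR: mean(logit(p_i)) * n^alpha.
    n = len(probs)
    if n == 0:
        raise ValueError("probs must not be empty")
    mean_logit = sum(logit(p) for p in probs) / n
    return sigmoid(mean_logit * n ** alpha)


def old_log_odds_conjunction(probs: Sequence[float], alpha: float = 0.5) -> float:
    # Replaced form: logit(geo_mean) + alpha * ln(n).
    n = len(probs)
    if n == 0:
        raise ValueError("probs must not be empty")
    geo_mean = math.exp(sum(math.log(max(p, EPS)) for p in probs) / n)
    return sigmoid(logit(geo_mean) + alpha * math.log(n))
```

With two uninformative signals at 0.5, the corrected form returns exactly 0.5, while the replaced form adds alpha*ln(2) to the logit and drifts above 0.5 even though neither signal carries evidence.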

Commit 2: Benchmark runner improvements

  • Add --embedding-model option for live sentence-transformers encoding
  • Add balanced_log_odds_fusion scorer alongside hybrid_or/hybrid_and
  • Fix O(Q*D*Q) hybrid scoring bottleneck, reducing it to O(Q*D)
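
The complexity change can be illustrated with a small Python sketch. The runner code itself is not shown in this PR description, so this assumes the extra factor of Q came from looking up each query's embedding with a linear scan inside the query/document loop; `score_pairs_slow`, `score_pairs_fast`, and `dot` are hypothetical names, not functions from the benchmark runner.

```python
from typing import Dict, List, Tuple

Vector = List[float]
Record = Tuple[str, Vector]  # (id, embedding)


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def score_pairs_slow(queries: List[Record], docs: List[Record]) -> Dict[Tuple[str, str], float]:
    # O(Q*D*Q): every (query, doc) pair rescans the query list for the embedding.
    scores = {}
    for qid, _ in queries:
        for did, dvec in docs:
            qvec = next(vec for other_id, vec in queries if other_id == qid)
            scores[(qid, did)] = dot(qvec, dvec)
    return scores


def score_pairs_fast(queries: List[Record], docs: List[Record]) -> Dict[Tuple[str, str], float]:
    # O(Q*D): the embedding is passed along with the query, so no lookup is needed.
    scores = {}
    for qid, qvec in queries:
        for did, dvec in docs:
            scores[(qid, did)] = dot(qvec, dvec)
    return scores
```

Both functions return the same scores; only the amount of work per pair differs.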

Commit 3: Advanced modules

  • probability.rs: BayesianProbabilityTransform with batch fit() (SGD) and online update() with Polyak averaging
  • learnable_weights.rs: LearnableLogOddsWeights -- per-signal reliability weights via softmax + Hebbian gradient
  • attention_weights.rs: AttentionLogOddsWeights -- query-dependent signal weighting via learned attention
  • metrics.rs: ECE, Brier score, reliability diagrams, CalibrationReport
  • debug.rs: FusionDebugger with trace_bm25, trace_vector, trace_fusion, compare
  • fusion.rs: Add relu/swish gating to log_odds_conjunction
  • Full Python bindings for all new types and functions
  • Benchmark runner updated with all 10 scorers + calibration diagnostics
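
The calibration metrics named above have standard definitions, sketched here in Python. This is not the metrics.rs implementation: the equal-width binning, the default `n_bins=10`, and the function names are assumptions, and inputs are assumed to be probabilities in [0, 1] with binary labels.

```python
from typing import List, Sequence, Tuple


def brier_score(probs: Sequence[float], labels: Sequence[int]) -> float:
    # Mean squared difference between predicted probability and outcome.
    if len(probs) != len(labels) or len(probs) == 0:
        raise ValueError("probs and labels must be non-empty and equal length")
    return sum((p - y) ** 2 for p, y in zip(probs, labels)) / len(probs)


def expected_calibration_error(
    probs: Sequence[float], labels: Sequence[int], n_bins: int = 10
) -> float:
    # Weighted average over bins of |observed accuracy - mean confidence|.
    if len(probs) != len(labels) or len(probs) == 0:
        raise ValueError("probs and labels must be non-empty and equal length")
    bins: List[List[Tuple[float, int]]] = [[] for _ in range(n_bins)]
    for p, y in zip(probs, labels):
        idx = min(int(p * n_bins), n_bins - 1)  # p == 1.0 goes in the last bin
        bins[idx].append((p, y))
    total = len(probs)
    ece = 0.0
    for bucket in bins:
        if not bucket:
            continue
        confidence = sum(p for p, _ in bucket) / len(bucket)
        accuracy = sum(y for _, y in bucket) / len(bucket)
        ece += (len(bucket) / total) * abs(accuracy - confidence)
    return ece
```

A scorer that always outputs 1.0 on a half-relevant sample gets a Brier score and an ECE of 0.5 under these definitions, while perfectly confident and correct predictions score 0 on both.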

Benchmark results (SQuAD 100 queries, BGE-M3, 2067 docs)

| Scorer | NDCG@10 | MAP@10 | MRR@10 |
|---|---|---|---|
| bm25 | 0.7260 | 0.7017 | 0.7017 |
| bayesian | 0.8255 | 0.7944 | 0.7944 |
| bayesian_fitted | 0.8333 | 0.8049 | 0.8049 |
| hybrid_or | 0.8313 | 0.8019 | 0.8019 |
| hybrid_and | 0.8350 | 0.8069 | 0.8069 |
| balanced_fusion | 0.8828 | 0.8474 | 0.8474 |
| gated_relu | 0.8350 | 0.8069 | 0.8069 |
| gated_swish | 0.8350 | 0.8069 | 0.8069 |
| learned_weights | 0.8313 | 0.8019 | 0.8019 |
| attention | 0.8333 | 0.7923 | 0.7923 |

Test plan

  • All 13 Rust experiments pass (cargo run -- --experiment all)
  • 69 Python tests pass across 4 test modules + smoke tests
  • Benchmark runner verified on SQuAD (BGE-M3) and SciFact (MiniLM-L6-v2)

- Add fusion.rs with standalone functions: cosine_to_probability,
  prob_not, prob_and, prob_or, log_odds_conjunction (paper Eq. 20/23,
  Theorem 8.3), and balanced_log_odds_fusion
- Fix log-odds conjunction formula in HybridScorer: replace incorrect
  logit(geo_mean) + alpha*ln(n) with correct mean(logit(p_i)) * n^alpha
- Add base_rate prior field to BayesianBM25Scorer with two-step Bayes
  update (Remark 4.4.5)
- Refactor BayesianBM25Scorer::score() to use fusion::prob_or
- Add Python bindings for all new fusion functions with validation
- Add experiments 11 (base rate prior), 12 (log-odds conjunction
  properties), 13 (fusion primitives); all 13 experiments pass
- Fix benchmark runner to skip TSV header rows in qrels files
- Add --embedding-model option for live sentence-transformers encoding
- Add balanced_log_odds_fusion scorer to benchmark alongside hybrid_or/hybrid_and
- Add evaluate_hybrid() to pass query embeddings directly (O(Q*D) vs O(Q*D*Q))
- Add evaluate_balanced_fusion() for batch logit-space fusion scoring
- Use dataclass defaults for DocRecord.embedding and QueryRecord fields
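
The two-step Bayes update for `base_rate` in the list above can be pictured with a short Python sketch. This is an assumption about the shape of the update, not a transcription of Remark 4.4.5 or of `BayesianBM25Scorer`: it assumes the first step combines a likelihood with a per-document prior, and the second step updates that posterior against the corpus-level `base_rate`. The names `bayes_update` and `two_step_posterior` are hypothetical.

```python
def bayes_update(likelihood: float, prior: float) -> float:
    # Binary Bayes rule: L*p / (L*p + (1 - L)*(1 - p)).
    numerator = likelihood * prior
    denominator = numerator + (1.0 - likelihood) * (1.0 - prior)
    if denominator == 0.0:
        raise ValueError("contradictory inputs: one is 0 and the other is 1")
    return numerator / denominator


def two_step_posterior(likelihood: float, prior: float, base_rate: float) -> float:
    # Step 1: likelihood against the per-document prior (assumed).
    first = bayes_update(likelihood, prior)
    # Step 2: the result against the corpus-level base rate (assumed).
    return bayes_update(first, base_rate)
```

Under this form each step adds the logit of its prior to the running log-odds, so a `base_rate` of 0.5 leaves the posterior unchanged and a `base_rate` below 0.5 pulls every score down.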
…to Rust

Add 5 new Rust modules with full Python bindings:
- probability: BayesianProbabilityTransform with batch fit and online SGD update
- learnable_weights: LearnableLogOddsWeights with Hebbian gradient learning
- attention_weights: AttentionLogOddsWeights with query-dependent signal weighting
- metrics: ECE, Brier score, reliability diagrams, CalibrationReport
- debug: FusionDebugger with trace_bm25, trace_vector, trace_fusion, compare

Also adds relu/swish gating to log_odds_conjunction and cosine_to_probability.
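
A rough Python sketch of what relu/swish gating could look like in log-odds space follows. The definitions relu(x) = max(0, x) and swish(x) = x * sigmoid(x) are standard, but where the gate is applied is an assumption: here each signal's logit passes through the gate before averaging. The `gating` parameter name, the default `alpha=0.5`, and the clamp `EPS` are also assumptions, not the fusion.rs signature.

```python
import math
from typing import Sequence

EPS = 1e-10  # assumed clamp that keeps logit() finite


def logit(p: float) -> float:
    p = min(max(p, EPS), 1.0 - EPS)
    return math.log(p / (1.0 - p))


def sigmoid(x: float) -> float:
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def relu(x: float) -> float:
    return max(0.0, x)


def swish(x: float) -> float:
    return x * sigmoid(x)


def gated_log_odds_conjunction(
    probs: Sequence[float], alpha: float = 0.5, gating: str = "none"
) -> float:
    gates = {"none": lambda x: x, "relu": relu, "swish": swish}
    if gating not in gates:
        raise ValueError(f"unknown gating: {gating}")
    n = len(probs)
    if n == 0:
        raise ValueError("probs must not be empty")
    gate = gates[gating]
    # Gate each signal's logit, then apply mean * n^alpha as in the ungated form.
    mean_logit = sum(gate(logit(p)) for p in probs) / n
    return sigmoid(mean_logit * n ** alpha)
```

In this sketch relu discards negative evidence entirely and swish only dampens it, so for a mixed input such as `[0.2, 0.9]` the fused score rises from ungated to swish to relu.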

Update benchmark runner to evaluate all new algorithms (gated fusion,
learned weights, attention weights, fitted Bayesian, calibration diagnostics).

Split tests into per-module files: test_probability, test_fusion, test_metrics,
test_debug (69 tests total).

jaepil changed the title from "Port 5 high-priority features and add balanced fusion benchmark" to "Port advanced fusion, calibration, and debugging from Python to Rust" on Mar 5, 2026