Port advanced fusion, calibration, and debugging from Python to Rust#2
Open
jaepil wants to merge 3 commits intoinstructkr:mainfrom
Open
Port advanced fusion, calibration, and debugging from Python to Rust#2jaepil wants to merge 3 commits intoinstructkr:mainfrom
jaepil wants to merge 3 commits intoinstructkr:mainfrom
Conversation
- Add fusion.rs with standalone functions: cosine_to_probability, prob_not, prob_and, prob_or, log_odds_conjunction (paper Eq. 20/23, Theorem 8.3), and balanced_log_odds_fusion - Fix log-odds conjunction formula in HybridScorer: replace incorrect logit(geo_mean) + alpha*ln(n) with correct mean(logit(p_i)) * n^alpha - Add base_rate prior field to BayesianBM25Scorer with two-step Bayes update (Remark 4.4.5) - Refactor BayesianBM25Scorer::score() to use fusion::prob_or - Add Python bindings for all new fusion functions with validation - Add experiments 11 (base rate prior), 12 (log-odds conjunction properties), 13 (fusion primitives); all 13 experiments pass - Fix benchmark runner to skip TSV header rows in qrels files
- Add --embedding-model option for live sentence-transformers encoding - Add balanced_log_odds_fusion scorer to benchmark alongside hybrid_or/hybrid_and - Add evaluate_hybrid() to pass query embeddings directly (O(Q*D) vs O(Q*D*Q)) - Add evaluate_balanced_fusion() for batch logit-space fusion scoring - Use dataclass defaults for DocRecord.embedding and QueryRecord fields
…to Rust Add 5 new Rust modules with full Python bindings: - probability: BayesianProbabilityTransform with batch fit and online SGD update - learnable_weights: LearnableLogOddsWeights with Hebbian gradient learning - attention_weights: AttentionLogOddsWeights with query-dependent signal weighting - metrics: ECE, Brier score, reliability diagrams, CalibrationReport - debug: FusionDebugger with trace_bm25, trace_vector, trace_fusion, compare Also adds relu/swish gating to log_odds_conjunction and cosine_to_probability. Update benchmark runner to evaluate all new algorithms (gated fusion, learned weights, attention weights, fitted Bayesian, calibration diagnostics). Split tests into per-module files: test_probability, test_fusion, test_metrics, test_debug (69 tests total).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Full port of bayesian-bm25 Python features (v0.4.0-v0.8.0) to Rust with Python bindings.
Commit 1: Port 5 high-priority features
fusion.rs:cosine_to_probability,prob_not,prob_and,prob_or,log_odds_conjunction(paper Eq. 20/23, Theorem 8.3),balanced_log_odds_fusionlogit(geo_mean) + alpha*ln(n)with correctmean(logit(p_i)) * n^alphabase_rateprior toBayesianBM25Scorer(Remark 4.4.5)BayesianBM25Scorer::score()to usefusion::prob_orCommit 2: Benchmark runner improvements
--embedding-modeloption for live sentence-transformers encodingbalanced_log_odds_fusionscorer alongsidehybrid_or/hybrid_andCommit 3: Advanced modules
BayesianProbabilityTransformwith batchfit()(SGD) and onlineupdate()with Polyak averagingLearnableLogOddsWeights-- per-signal reliability weights via softmax + Hebbian gradientAttentionLogOddsWeights-- query-dependent signal weighting via learned attentionCalibrationReportFusionDebuggerwithtrace_bm25,trace_vector,trace_fusion,comparelog_odds_conjunctionBenchmark results (SQuAD 100 queries, BGE-M3, 2067 docs)
Test plan
cargo run -- --experiment all)