Finish Metropolis-Hastings sampling for JAX backend#1580
Finish Metropolis-Hastings sampling for JAX backend#1580
Conversation
Agent-Logs-Url: https://github.com/FlorianPfaff/PyRecEst/sessions/17115f24-1242-4133-aeaa-40db3ad34b71 Co-authored-by: FlorianPfaff <6773539+FlorianPfaff@users.noreply.github.com>
…er, remove duplicate comment Agent-Logs-Url: https://github.com/FlorianPfaff/PyRecEst/sessions/17115f24-1242-4133-aeaa-40db3ad34b71 Co-authored-by: FlorianPfaff <6773539+FlorianPfaff@users.noreply.github.com>
❌MegaLinter analysis: Error
Detailed Issues❌ PYTHON / flake8 - 3 errors❌ PYTHON / pylint - 3 errorsSee detailed reports in MegaLinter artifacts Your project could benefit from a custom flavor, which would allow you to run only the linters you need, and thus improve runtime performances. (Skip this info by defining
|
|
@copilot Can you fix the issues that Megalinter complains about? |
|
@copilot Can you fix the issues that Megalinter complains about? |
|
@copilot Can you fix the issues that Megalinter complains about? |
Agent-Logs-Url: https://github.com/FlorianPfaff/PyRecEst/sessions/afb8c469-6fe3-4c7b-8e3a-c653902815d2 Co-authored-by: FlorianPfaff <6773539+FlorianPfaff@users.noreply.github.com>
Head branch was pushed to by a user without write access
Agent-Logs-Url: https://github.com/FlorianPfaff/PyRecEst/sessions/afb8c469-6fe3-4c7b-8e3a-c653902815d2 Co-authored-by: FlorianPfaff <6773539+FlorianPfaff@users.noreply.github.com>
Fixed and pushed in commit I addressed the MegaLinter failures by resolving the flake8 and pylint issues in the MH/JAX-related files with minimal code changes, then re-ran targeted lint checks and MH sampling tests ( No UI was changed in this PR (backend-only), so there is no UI screenshot to provide. |
Addressed in commit The MegaLinter issues previously reported (flake8/pylint in the MH/JAX files) were fixed there, and targeted lint/test validation passed locally. The latest MegaLinter run for that commit is |
Test Results 8 files 8 suites 1h 31m 5s ⏱️ Results for commit 8c0ec32. |

MH sampling was explicitly blocked on JAX (
assert backend != "jax") with no implementation. This completes the JAX path end-to-end.JAX random state API
get_state→_get_stateinjax/random.py; exposed publicget_state/set_statealiases so MH can split/update the global PRNG keymultinomialfor JAX viajax.random.categorical+jnp.bincount(was previously a no-op returningNotImplementedError)get_state/set_stateto numpy and pytorch backends for API symmetryCore MH sampler
Replaced the JAX-blocking assertion with a JAX-specific code path in
sample_metropolis_hastings. Uses a plain Python loop rather thanlax.scanso that scipy-basedpdfmethods (e.g.HypertoroidalWrappedNormalDistribution) remain compatible:Added
_assert_proposal_supports_key()to give a clear error when a single-argument numpy-style proposal is passed with the JAX backend.JAX-compatible default proposals
Added
(key, x) → x_propdefault proposals to:AbstractHypersphericalDistribution— uniform on S^d via normal + normalize (with explicit re-normalization to avoid float32 drift)AbstractHyperhemisphericalDistribution— same, then flip sign if last coord < 0AbstractHypertoroidalDistribution— wrapped Gaussian stepMixture sampling
AbstractMixture.samplenow has a JAX branch that iterates components and falls back fromsample()tosample_metropolis_hastings()when the direct sampler is unsupported (catchesNotImplementedError,AssertionError,ValueError,TypeError).Other branch changes
apply_function: renamed parameterfunction_is_vectorized→f_supports_multiple; removedbeartypedecoratorAbstractSphericalHarmonicsDistribution.normalize_in_place: removed unusedwarn_unnormkwargCircularFourierDistributionarithmetic: preservenafter+/-operationsPartiallyWrappedNormalDistribution.set_mode: simplified to direct assignmenttest_sample_metropolis_hastings_basics_only_{t2,s2,h2}; cleaned up redundantwarnings.catch_warningsblocks