Skip to content

Finish Metropolis-Hastings sampling for JAX backend#1580

Open
Copilot wants to merge 6 commits intomainfrom
copilot/finish-metropolis-hastings-jax
Open

Finish Metropolis-Hastings sampling for JAX backend#1580
Copilot wants to merge 6 commits intomainfrom
copilot/finish-metropolis-hastings-jax

Conversation

Copy link
Copy Markdown
Contributor

Copilot AI commented Mar 31, 2026

MH sampling was explicitly blocked on JAX (assert backend != "jax") with no implementation. This completes the JAX path end-to-end.

JAX random state API

  • Renamed internal get_state_get_state in jax/random.py; exposed public get_state/set_state aliases so MH can split/update the global PRNG key
  • Implemented multinomial for JAX via jax.random.categorical + jnp.bincount (was previously a no-op returning NotImplementedError)
  • Added get_state/set_state to numpy and pytorch backends for API symmetry

Core MH sampler

Replaced the JAX-blocking assertion with a JAX-specific code path in sample_metropolis_hastings. Uses a plain Python loop rather than lax.scan so that scipy-based pdf methods (e.g. HypertoroidalWrappedNormalDistribution) remain compatible:

# JAX proposals must now accept (key, x) → x_prop
def proposal(key, x):
    key, subkey = jax.random.split(key)
    return jnp.mod(x + jax.random.normal(subkey, x.shape), 2 * jnp.pi)

samples = dist.sample_metropolis_hastings(n=100, proposal=proposal, start_point=x0)

Added _assert_proposal_supports_key() to give a clear error when a single-argument numpy-style proposal is passed with the JAX backend.

JAX-compatible default proposals

Added (key, x) → x_prop default proposals to:

  • AbstractHypersphericalDistribution — uniform on S^d via normal + normalize (with explicit re-normalization to avoid float32 drift)
  • AbstractHyperhemisphericalDistribution — same, then flip sign if last coord < 0
  • AbstractHypertoroidalDistribution — wrapped Gaussian step

Mixture sampling

AbstractMixture.sample now has a JAX branch that iterates components and falls back from sample() to sample_metropolis_hastings() when the direct sampler is unsupported (catches NotImplementedError, AssertionError, ValueError, TypeError).

Other branch changes

  • apply_function: renamed parameter function_is_vectorizedf_supports_multiple; removed beartype decorator
  • AbstractSphericalHarmonicsDistribution.normalize_in_place: removed unused warn_unnorm kwarg
  • CircularFourierDistribution arithmetic: preserve n after +/- operations
  • PartiallyWrappedNormalDistribution.set_mode: simplified to direct assignment
  • Test files: removed JAX skip decorators from test_sample_metropolis_hastings_basics_only_{t2,s2,h2}; cleaned up redundant warnings.catch_warnings blocks

Copilot AI and others added 2 commits March 31, 2026 12:58
…er, remove duplicate comment

Agent-Logs-Url: https://github.com/FlorianPfaff/PyRecEst/sessions/17115f24-1242-4133-aeaa-40db3ad34b71

Co-authored-by: FlorianPfaff <6773539+FlorianPfaff@users.noreply.github.com>
@github-actions
Copy link
Copy Markdown

github-actions bot commented Mar 31, 2026

MegaLinter analysis: Error

Descriptor Linter Files Fixed Errors Warnings Elapsed time
✅ COPYPASTE jscpd yes no no 7.71s
✅ JSON prettier 2 0 0 0 0.4s
✅ JSON v8r 2 0 0 2.78s
✅ MARKDOWN markdownlint 1 0 0 0 0.68s
✅ MARKDOWN markdown-table-formatter 1 0 0 0 0.21s
✅ PYTHON bandit 249 0 0 3.11s
✅ PYTHON black 249 7 0 0 4.72s
❌ PYTHON flake8 249 3 0 1.74s
✅ PYTHON isort 249 7 0 0 0.49s
✅ PYTHON mypy 249 0 0 3.94s
❌ PYTHON pylint 249 3 0 70.19s
✅ PYTHON ruff 249 9 0 0 0.05s
✅ REPOSITORY checkov yes no no 21.47s
✅ REPOSITORY gitleaks yes no no 4.04s
✅ REPOSITORY git_diff yes no no 0.04s
✅ REPOSITORY secretlint yes no no 6.31s
✅ REPOSITORY syft yes no no 3.8s
✅ REPOSITORY trivy-sbom yes no no 1.79s
✅ REPOSITORY trufflehog yes no no 15.21s
✅ YAML prettier 4 0 0 0 0.47s
✅ YAML v8r 4 0 0 5.57s
✅ YAML yamllint 4 0 0 0.46s

Detailed Issues

❌ PYTHON / flake8 - 3 errors
pyrecest/distributions/abstract_mixture.py:92:24: E203 whitespace before ':'
pyrecest/tests/distributions/test_spherical_harmonics_distribution_complex.py:2:1: F401 'warnings' imported but unused
pyrecest/tests/test_evaluation_basic.py:5:1: F401 'warnings' imported but unused
❌ PYTHON / pylint - 3 errors
************* Module pyrecest.tests.distributions.test_spherical_harmonics_distribution_complex
pyrecest/tests/distributions/test_spherical_harmonics_distribution_complex.py:2:0: W0611: Unused import warnings (unused-import)
************* Module pyrecest.tests.test_evaluation_basic
pyrecest/tests/test_evaluation_basic.py:5:0: W0611: Unused import warnings (unused-import)
************* Module update_init_helper
update_init_helper.py:1:0: R0801: Similar lines in 2 files
==pyrecest.distributions.hypersphere_subset.abstract_hyperhemispherical_distribution:[67:78]
==pyrecest.distributions.hypersphere_subset.abstract_hyperspherical_distribution:[80:90]
                proposal = proposal_np
            else:
                import jax as _jax  # pylint: disable=import-error
                import jax.numpy as _jnp  # pylint: disable=import-error

                def proposal_jax(key, _):
                    """JAX independence proposal: uniform on hypersphere."""
                    key, subkey = _jax.random.split(key)
                    s = _jax.random.normal(subkey, shape=(1, self.dim + 1))
                    # Ensure exact unit norm (avoids float32 rounding errors) (duplicate-code)

------------------------------------
Your code has been rated at 10.00/10

See detailed reports in MegaLinter artifacts

Your project could benefit from a custom flavor, which would allow you to run only the linters you need, and thus improve runtime performances. (Skip this info by defining FLAVOR_SUGGESTIONS: false)

  • Documentation: Custom Flavors
  • Command: npx mega-linter-runner@9.4.0 --custom-flavor-setup --custom-flavor-linters PYTHON_PYLINT,PYTHON_BLACK,PYTHON_FLAKE8,PYTHON_ISORT,PYTHON_BANDIT,PYTHON_MYPY,PYTHON_RUFF,COPYPASTE_JSCPD,JSON_V8R,JSON_PRETTIER,MARKDOWN_MARKDOWNLINT,MARKDOWN_MARKDOWN_TABLE_FORMATTER,REPOSITORY_CHECKOV,REPOSITORY_GIT_DIFF,REPOSITORY_GITLEAKS,REPOSITORY_SECRETLINT,REPOSITORY_SYFT,REPOSITORY_TRIVY_SBOM,REPOSITORY_TRUFFLEHOG,YAML_PRETTIER,YAML_YAMLLINT,YAML_V8R

MegaLinter is graciously provided by OX Security
Show us your support by starring ⭐ the repository

@FlorianPfaff
Copy link
Copy Markdown
Owner

@copilot Can you fix the issues that Megalinter complains about?

@FlorianPfaff
Copy link
Copy Markdown
Owner

@copilot Can you fix the issues that Megalinter complains about?

@FlorianPfaff
Copy link
Copy Markdown
Owner

@copilot Can you fix the issues that Megalinter complains about?

auto-merge was automatically disabled March 31, 2026 15:03

Head branch was pushed to by a user without write access

Copy link
Copy Markdown
Contributor Author

Copilot AI commented Mar 31, 2026

@copilot Can you fix the issues that Megalinter complains about?

Fixed and pushed in commit b911ffa.

I addressed the MegaLinter failures by resolving the flake8 and pylint issues in the MH/JAX-related files with minimal code changes, then re-ran targeted lint checks and MH sampling tests (3 passed).

No UI was changed in this PR (backend-only), so there is no UI screenshot to provide.

Copy link
Copy Markdown
Contributor Author

Copilot AI commented Mar 31, 2026

@copilot Can you fix the issues that Megalinter complains about?

Addressed in commit b911ffa.

The MegaLinter issues previously reported (flake8/pylint in the MH/JAX files) were fixed there, and targeted lint/test validation passed locally. The latest MegaLinter run for that commit is action_required with no jobs executed, so there are no new linter diagnostics in that run to fix.

@FlorianPfaff FlorianPfaff enabled auto-merge March 31, 2026 15:15
@github-actions
Copy link
Copy Markdown

Test Results

    8 files      8 suites   1h 31m 5s ⏱️
  533 tests   533 ✅   0 💤 0 ❌
4 456 runs  4 196 ✅ 260 💤 0 ❌

Results for commit 8c0ec32.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants