Skip to content

Weight structured sparsity in training#3561

Draft
BirdsOfAFthr wants to merge 1 commit intomainfrom
amandaliang
Draft

Weight structured sparsity in training#3561
BirdsOfAFthr wants to merge 1 commit intomainfrom
amandaliang

Conversation

@BirdsOfAFthr
Copy link
Copy Markdown
Collaborator

Description

  • Enable Structured Sparsity in Qwix:

    • Added get_fp8_full_qwix_rule_w_sparsity in quantizations.py to configure N:M sparsity (N, M, and update steps) via additional_qt_config.
  • Support Non-Differentiable State (batch_stats):

    • Updated loss_fn, train_step, and eval_step in train.py to accept and pass batch_stats.
    • Wrapped model application variables to include {"params": params, "batch_stats": sparsity_state}.
    • Ensured train_step filters gradients so that they are only applied to differentiable parameters, while batch_stats are updated from the model's auxiliary outputs.
  • Safe Partial Checkpoint Loading:

    • Updated setup_initial_state in maxtext_utils.py to intelligently merge loaded parameters. It now ignores jax.ShapeDtypeStruct placeholders (which indicate un-loaded variables) and falls back to initialized parameters.

Tests

  • Added internal unit tests (will be in the G3 change)
  • E2E benchmarking

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@BirdsOfAFthr BirdsOfAFthr changed the title Integrate weight structured sparsity in training Weight structured sparsity in training Apr 2, 2026
@BirdsOfAFthr BirdsOfAFthr force-pushed the amandaliang branch 2 times, most recently from f876cb6 to d3c4615 Compare April 4, 2026 02:42
@BirdsOfAFthr BirdsOfAFthr marked this pull request as ready for review April 4, 2026 02:45
@BirdsOfAFthr BirdsOfAFthr marked this pull request as draft April 4, 2026 02:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant