
Standardize SFT-Tunix integration and add manual step timing#3576

Draft
igorts-git wants to merge 1 commit into main from igorts/sft-tunix-standardization

Conversation

@igorts-git
Collaborator

Description

This PR refactors the SFT integration with the Tunix library. The change is required to support models with complex sharding rules (such as Gemma2) and to fix a crash in the metric reporting path that affected all SFT runs.

Key Changes

  1. Standardized Model Interface (train_sft.py)
  • Switched to using the official TunixMaxTextAdapter. This eliminates ad-hoc naming wrappers and provides a consistent bridge for argument names (e.g., mapping generic positions to MaxText's decoder_positions).
  • Implemented Manual Pre-sharding of the optimizer. We now shard the optimizer state within MaxText's full JAX context before passing it to Tunix.
  • Enabled the skip_sharding_optimizer flag. This tells the Tunix trainer to respect MaxText’s existing sharding layout rather than attempting to re-shard, which previously caused crashes on scalar values (like the step counter).
  • Wrapped the PeftTrainer lifecycle in the necessary JAX contexts (jax.set_mesh and nn.logical_axis_rules).
  2. Stable Metric Reporting (hooks.py)
  • Implemented Manual Step Timing. Previously, the integration relied on a hardcoded 0.0 value passed from the trainer, which triggered a ZeroDivisionError in the MaxText MetricLogger when calculating TFLOPs.
  • Added a step_start_time tracker to the hooks to ensure accurate, non-zero reporting of step durations.
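The pre-sharding pattern above can be sketched as a self-contained toy. This is illustrative only, not the actual MaxText/Tunix code: the optimizer-state dict, mesh construction, and partition specs are assumptions, and the real integration shards against MaxText's logical axis rules inside its full JAX context.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D device mesh (works on a single CPU device as well).
mesh = Mesh(np.array(jax.devices()), ("data",))

# Toy optimizer state: two array moments plus a scalar step counter.
# Naively re-sharding the scalar downstream is the kind of failure the
# skip_sharding_optimizer flag avoids in the real integration.
opt_state = {
    "mu": jnp.zeros((8, 4)),
    "nu": jnp.zeros((8, 4)),
    "count": jnp.zeros((), dtype=jnp.int32),
}

def shard_leaf(x):
    # Partition arrays along the mesh axis; replicate scalars, which
    # have no dimension to split.
    spec = PartitionSpec("data") if x.ndim > 0 else PartitionSpec()
    return jax.device_put(x, NamedSharding(mesh, spec))

# Shard the state up front, before handing it to the external trainer,
# so the trainer can be told to respect the existing layout.
opt_state = jax.tree_util.tree_map(shard_leaf, opt_state)
```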

Why this is needed

  • Gemma2 Support: Gemma2 uses a logical 'norm' axis. Without the correct context and manual sharding established in this PR, JAX crashes because it cannot map this logical axis to physical hardware during the trainer's initialization.
  • Stability: Any SFT run shorter than the rampup_end_step would previously crash on the first step due to the division-by-zero bug in the logger.
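The timing fix can be illustrated with a minimal sketch (class and method names here are hypothetical, not the actual hooks.py API): recording a wall-clock start time at the beginning of each step guarantees the duration fed into the throughput calculation is non-zero.

```python
import time

class StepTimingHook:
    """Minimal sketch of manual step timing. Previously a hardcoded 0.0
    step duration reached the metric logger, and the TFLOPs-per-second
    computation divided by it, raising ZeroDivisionError."""

    def __init__(self):
        self.step_start_time = None

    def on_step_start(self):
        # Record wall-clock time at the start of every training step.
        self.step_start_time = time.perf_counter()

    def on_step_end(self, step_tflops):
        # Measure real elapsed time; it is strictly positive, so the
        # throughput division below is always safe.
        step_time = time.perf_counter() - self.step_start_time
        return step_tflops / step_time  # TFLOP/s

hook = StepTimingHook()
hook.on_step_start()
time.sleep(0.001)  # stand-in for the actual training step
throughput = hook.on_step_end(step_tflops=1.0)
```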

Tests

  • Model: gemma2-2b (NNX version) on a local TPU v4 VM
  • Result: Successfully completed 2-step training and evaluation cycles with correct metric logging and zero sharding errors.

Note: This PR assumes the corresponding change in the Tunix library (adding the skip_sharding_optimizer flag) is present in the environment.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov bot commented Apr 6, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.


