74 changes: 74 additions & 0 deletions post_training_plan.md
@@ -0,0 +1,74 @@
# Brainstorming: MaxText & Tunix Post-Training Integration Plan

## 1. Executive Summary
The goal is to provide a best-in-class post-training suite (SFT, DPO, RLHF, GRPO) that scales to the largest models and TPU slices.

Instead of maintaining duplicate implementations of complex alignment algorithms, we will establish a **"Hybrid Core"** architecture:
* **MaxText** acts as the **Performance Engine**: Providing optimized model implementations (NNX), robust sharding/SPMD rules, and high-throughput data loading.
* **Tunix** acts as the **Algorithmic Orchestrator**: Providing the training loops, specialized loss functions (DPO, PPO), and alignment-specific metrics.

## 2. Shared Responsibilities & Strengths

| Feature | MaxText Strength | Tunix Strength | Recommended Primary |
| :--- | :--- | :--- | :--- |
| **Model Arch** | Highly optimized, NNX-based, TPU-aware | Research-flexible | **MaxText** |
| **Sharding** | Robust logical-to-physical SPMD rules | Basic/Standard sharding | **MaxText** |
| **Dataloading** | Multi-host Grain integration | HF Datasets convenience | **Collaborative** (MaxText Grain + Tunix Prep) |
| **Loss Functions**| Standard Cross-Entropy | DPO, ORPO, PPO, GRPO | **Tunix** |
| **Metrics** | Goodput, Hardware utilization | KL-Divergence, Rewards, Accuracy | **Tunix** (Loop) + **MaxText** (System) |

## 3. The "Bridge" Architecture (Implementation Strategy)

To make these two projects work together without "technical friction," we should standardize the following interfaces:

### A. The Model Adapter (Unified Naming Bridge)
We discovered that `src/maxtext/integration/tunix/tunix_adapter.py` already contains a robust `TunixMaxTextAdapter`.
* **Current State:** Used effectively in RL (`train_rl.py`).
* **Action:** Refactor SFT (`train_sft.py`) and DPO (`train_dpo.py`) to use this same adapter instead of ad-hoc wrappers. This ensures that any model supported by MaxText is immediately compatible with all Tunix trainers.

### B. Sharding-Aware Initialization
Tunix's `PeftTrainer` currently makes assumptions about sharding that clash with MaxText's more advanced SPMD rules (e.g., the `norm` axis issue and scalar optimizer states).
* **Current State:** Handled via manual "pre-sharding" and no-op overrides in DPO.
* **Action:** Move this logic into a base `MaxTextTunixTrainer` class or a utility function used by all post-training scripts.
* **Action:** Contribute to Tunix to make its internal `_shard_optimizer` check for existing sharding before applying constraints.

### C. Standardized Data Schema (The "Input Bridge")
MaxText's multi-host Grain loader requires numeric arrays, while Tunix often expects strings.
* **Current State:** SFT/DPO/RL each handle this differently.
* **Action:** Standardize on a "Pre-tokenized numeric schema" where MaxText performs tokenization and padding (using DPO-aware left-padding when needed) and provides the `_ids` and `_mask` columns Tunix expects for pre-tokenized input.
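
The pre-tokenized schema can be sketched as follows. The `*_ids`/`*_mask` column names follow the Tunix convention targeted above; the pad id, length budgets, and helper name are illustrative assumptions, not the actual pipeline code:

```python
import numpy as np

PAD_ID = 0          # illustrative pad token id
MAX_PROMPT = 8      # per-field length budget (assumption)
MAX_RESPONSE = 8

def pad_ids(ids, length, left=False):
  """Trim to `length`, then pad with PAD_ID (left for prompts, right for responses)."""
  ids = np.asarray(ids)[:length]
  amount = length - ids.shape[0]
  width = ((amount, 0),) if left else ((0, amount),)
  return np.pad(ids, width, constant_values=PAD_ID)

prompt = pad_ids([101, 7, 9], MAX_PROMPT, left=True)   # DPO-aware left padding
chosen = pad_ids([11, 12, 13, 2], MAX_RESPONSE)
record = {
    "prompt_ids": prompt,
    "prompt_mask": (prompt != PAD_ID).astype(np.int32),
    "chosen_ids": chosen,
    "chosen_mask": (chosen != PAD_ID).astype(np.int32),
}
```

Left-padding the prompt keeps the prompt/response boundary at a fixed position, which simplifies concatenating `prompt_ids` with each response inside the trainer.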

## 4. Documentation Strategy

Existing documentation is fragmented (`docs/tutorials/posttraining/sft.md`, `rl.md`, etc.).
* **Action:** Create a unified `post_training_overview.md` that explains the MaxText-Tunix relationship (MaxText=Engine, Tunix=Brain).
* **Action:** Ensure all tutorials consistently mention the `maxtext[tpu-post-train]` installation requirement.

## 5. Collaborative Enhancements (Modifications to Tunix)

To further reduce the "glue code" in MaxText, we should upstream the following improvements to the Tunix library:

### A. Flexible Sharding in `PeftTrainer`
Tunix's `_shard_optimizer` currently forces sharding constraints that can crash on pre-sharded MaxText states (especially with scalar values).
* **Action:** Modify `tunix/sft/peft_trainer.py` to only apply `with_sharding_constraint` if the optimizer is not already sharded or if a specific flag is set.
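
A minimal sketch of that guard (the helper name and tree layout are assumptions, not Tunix's actual internals):

```python
import jax
from jax.sharding import NamedSharding

def shard_if_needed(opt_state, pspecs):
  """Apply sharding constraints only to leaves that are not already sharded.

  Leaves that already carry a NamedSharding (e.g. pre-sharded by MaxText's
  logical axis rules) are returned untouched, avoiding clashes on
  pre-sharded or scalar optimizer states.
  """
  def shard_leaf(leaf, spec):
    if isinstance(getattr(leaf, "sharding", None), NamedSharding):
      return leaf  # already placed; skip the constraint
    return jax.lax.with_sharding_constraint(leaf, spec)
  return jax.tree_util.tree_map(shard_leaf, opt_state, pspecs)
```

An `allow_resharding` flag, as suggested above, could simply bypass the `isinstance` check to force Tunix's original behavior.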

### B. Generalized Model Call Interface
Tunix's `get_per_token_logps` hardcodes argument names like `positions` and `attention_mask`.
* **Action:** Update `tunix/rl/common.py` to allow passing a `name_mapping` dictionary. This would allow MaxText to tell Tunix: "Use `decoder_positions` instead of `positions`."
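
The idea can be sketched as follows (the function and argument names here are illustrative, not Tunix's real signature):

```python
def call_with_mapping(model_fn, name_mapping=None, **kwargs):
  """Invoke a model whose forward pass uses non-standard argument names.

  `name_mapping` translates canonical Tunix names to the model's own,
  e.g. {"positions": "decoder_positions"} for a MaxText decoder.
  """
  if name_mapping:
    kwargs = {name_mapping.get(k, k): v for k, v in kwargs.items()}
  return model_fn(**kwargs)

# A MaxText-style callable that expects `decoder_positions`:
def maxtext_forward(input_ids, decoder_positions, attention_mask):
  return len(input_ids), decoder_positions

result = call_with_mapping(
    maxtext_forward,
    name_mapping={"positions": "decoder_positions"},
    input_ids=[1, 2, 3],
    positions=[0, 1, 2],
    attention_mask=[1, 1, 1],
)
```

Callers that omit `name_mapping` get today's behavior unchanged, so the extension is backward compatible.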

## 6. Cleanup: Deleting Legacy Post-Training Support

As we transition to the Tunix-based "Hybrid Core" architecture, we should remove the legacy, non-Tunix implementations from MaxText to reduce maintenance burden.

### A. Remove Legacy DPO
The existing internal DPO implementation is fragmented and harder to maintain than the Tunix version.
* **Action:** Delete `src/maxtext/trainers/post_train/dpo/dpo_utils.py`.
* **Action:** Remove DPO-specific branches and imports in:
* `src/maxtext/trainers/pre_train/train.py`
* `src/maxtext/utils/train_utils.py`
* `src/MaxText/__init__.py`
* **Action:** Deprecate legacy DPO-specific configuration parameters in `src/maxtext/configs/base.yml` once the Tunix bridge is stable.

## 7. Roadmap for DPO Integration (Immediate Next Steps)
1. **Finalize the `ModelWrapper`:** Fix the "too many values to unpack" error by ensuring the wrapper returns only what Tunix needs (logits).
2. **Formalize the "No-Op" Sharding Override:** Instead of a lambda hack, create a proper `MaxTextDPOTrainer` subclass that overrides `_shard_optimizer` cleanly.
3. **Unified Config:** Allow users to specify `post_training_flavor: tunix_dpo` in `dpo.yml` to automatically trigger these bridge behaviors.
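
Step 2 could look like the following sketch (with a stand-in base class, since the real `DPOTrainer` lives in `tunix.sft.dpo.dpo_trainer`):

```python
class DPOTrainer:
  """Stand-in for Tunix's DPOTrainer; only the method of interest is shown."""

  def _shard_optimizer(self, optimizer):
    raise NotImplementedError("Tunix would apply sharding constraints here")


class MaxTextDPOTrainer(DPOTrainer):
  """DPO trainer whose optimizer state is pre-sharded by MaxText.

  MaxText already places the optimizer state via its logical axis rules,
  so re-applying Tunix's constraints would clash (e.g. on scalar states).
  A clean no-op override replaces the lambda hack.
  """

  def _shard_optimizer(self, optimizer):
    return optimizer  # already sharded by MaxText; nothing to do
```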
15 changes: 8 additions & 7 deletions src/maxtext/configs/post_train/dpo.yml
@@ -1,8 +1,8 @@
base_config: "base.yml"

use_dpo: true
train_data_columns: ['chosen', 'rejected']
eval_data_columns: ['chosen', 'rejected']
train_data_columns: ['input', 'chosen', 'rejected']
eval_data_columns: ['input', 'chosen', 'rejected']
base_output_directory: 'gs://maxtext-external/logs'

per_device_batch_size: 2.0
@@ -12,11 +12,12 @@ eval_interval: 5 # test eval once, in the middle of 10 training steps
eval_steps: 2

# TFDS Pipeline ----------------------
dataset_type: tfds
dataset_path: 'gs://maxtext-dataset/dpo/anthropic_rlhf'
dataset_name: 'tfds:1.0.0'
eval_dataset_name: 'tfds:1.0.0'
eval_split: 'test'
#dataset_type: tfds
#dataset_path: 'gs://maxtext-dataset/dpo/anthropic_rlhf'
#dataset_name: 'tfds:1.0.0'
#eval_dataset_name: 'tfds:1.0.0'
#eval_split: 'test'
packing: False # DEBUG: DO NOT MERGE.

# HF Pipeline -------------------------
hf_eval_split: 'test'
5 changes: 2 additions & 3 deletions src/maxtext/configs/pyconfig.py
@@ -53,6 +53,7 @@
"maxtext.trainers.pre_train.train_compile": "base.yml",
"maxtext.trainers.post_train.distillation.train_distill": "post_train/distillation.yml",
"maxtext.trainers.post_train.rl.train_rl": "post_train/rl.yml",
"maxtext.trainers.post_train.dpo.train_dpo": "post_train/dpo.yml",
"maxtext.trainers.post_train.sft.train_sft": "post_train/sft.yml",
"maxtext.trainers.post_train.sft.train_sft_deprecated": "post_train/sft.yml",
"maxtext.inference.decode": "base.yml",
@@ -83,9 +84,7 @@ def _resolve_or_infer_config(argv: list[str]) -> tuple[str, list[str]]:
return resolve_config_path(argv[1]), argv[2:]
module = _module_from_path(argv[0])
if module not in _CONFIG_FILE_MAPPING:
raise ValueError(
f"No config file provided and no default config found for module '{module}'"
)
raise ValueError(f"No config file provided and no default config found for module '{module}'")
config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
logger.warning("No config file provided, using default config mapping: %s", config_path)
return config_path, argv[1:]
41 changes: 40 additions & 1 deletion src/maxtext/input_pipeline/hf_data_processing.py
@@ -280,6 +280,9 @@ def preprocessing_pipeline(

pad_id = _get_pad_id(tokenizer)

# Tunix-DPO handles tokenization internally if strings are passed.
# However, MaxText's multihost loader requires numeric JAX arrays.
# We tokenize here and rename columns to match Tunix's TrainingInput requirements.
if tokenize:
dataset = dataset.map(
input_pipeline_utils.tokenization,
@@ -318,6 +321,41 @@ def lists2array(x):
return jax.tree.map(np.asarray, x, is_leaf=lambda y: isinstance(y, (list, tuple)))

operations.append(grain.MapOperation(lists2array))

# Generate masks and rename keys to match tunix.sft.dpo.dpo_trainer.TrainingInput
class DPOTunixPrep(grain.MapTransform):

def __init__(self, pad_id, max_prompt_length, max_response_length):
self.pad_id = pad_id
self.max_prompt_length = max_prompt_length
self.max_response_length = max_response_length

def _pad(self, x, length, left=False):
x = np.asarray(x)
pad_amount = max(length - x.shape[0], 0)
if left:
pad_width = ((pad_amount, 0),)
else:
pad_width = ((0, pad_amount),)
return np.pad(x[:length], pad_width, constant_values=self.pad_id)

def map(self, x):
prompt_ids = self._pad(x.pop("input"), self.max_prompt_length, left=True)
chosen_ids = self._pad(x.pop("chosen"), self.max_response_length, left=False)
rejected_ids = self._pad(x.pop("rejected"), self.max_response_length, left=False)

x["prompt_ids"] = prompt_ids
x["chosen_ids"] = chosen_ids
x["rejected_ids"] = rejected_ids
x["prompt_mask"] = (prompt_ids != self.pad_id).astype(np.int32)
x["chosen_mask"] = (chosen_ids != self.pad_id).astype(np.int32)
x["rejected_mask"] = (rejected_ids != self.pad_id).astype(np.int32)
return x

# Tunix DPO expects prompt and response to share the total budget.
dpo_max_prompt_len = max_target_length // 2
dpo_max_response_len = max_target_length // 2
operations.append(DPOTunixPrep(pad_id, dpo_max_prompt_len, dpo_max_response_len))
else:
assert len(data_column_names) == 1
operations.append(input_pipeline_utils.HFNormalizeFeatures(data_column_names[0]))
@@ -337,7 +375,8 @@ def lists2array(x):
)
operations.append(input_pipeline_utils.ReformatPacking(data_column_names))
else:
operations.append(input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id))
if not use_dpo:
operations.append(input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id))
operations.append(grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder))

if shift and not use_dpo:
207 changes: 207 additions & 0 deletions src/maxtext/trainers/post_train/dpo/train_dpo.py
@@ -0,0 +1,207 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DPO Training script that uses Tunix DPOTrainer on a MaxText model.

Example command:
Training & Evaluation:
python3 -m maxtext.trainers.post_train.dpo.train_dpo \
run_name=${WORKLOAD?} base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
tokenizer_path="google/gemma-2-2b-it" tokenizer_type=huggingface \
dataset_type="hf" hf_path="Anthropic/hh-rlhf" \
model_name=${MODEL?} load_parameters_path=${MAXTEXT_CONVERTED_CHECKPOINT?}/0/items \
hf_access_token=${HF_TOKEN?} per_device_batch_size=1 max_target_length=1024 \
eval_interval=2 eval_steps=2 steps=10 profiler=xplane weight_dtype=bfloat16
"""

from absl import app
import jax
import optax
from orbax import checkpoint as ocp
import pathwaysutils

import flax.linen as nn
from flax import nnx
from flax.linen import partitioning as nn_partitioning

from tunix.sft import metrics_logger, profiler
from tunix.sft.dpo.dpo_trainer import DPOTrainer, DPOTrainingConfig

import tunix
from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter
from maxtext.configs import pyconfig
from maxtext.utils import max_utils
from maxtext.common.goodput import (
GoodputEvent,
RECORD_JOB_END_TIME,
RECORD_JOB_START_TIME,
create_goodput_recorder,
maybe_monitor_goodput,
maybe_record_goodput,
record_goodput,
)
from maxtext.optimizers import optimizers
from maxtext.trainers.post_train.sft import hooks
from maxtext.utils import max_logging
from maxtext.utils import maxtext_utils
from maxtext.utils import model_creation_utils


def get_tunix_config(mt_config: pyconfig.HyperParameters) -> DPOTrainingConfig:
"""Gets the Tunix training configurations from the MaxText config.

Args:
mt_config: MaxText config.

Returns:
A Tunix `DPOTrainingConfig` object.
"""
# Checkpointing configurations
checkpointing_options = ocp.CheckpointManagerOptions(
save_interval_steps=mt_config.checkpoint_period,
enable_async_checkpointing=mt_config.async_checkpointing,
)

# Metrics configurations
metrics_logging_options = metrics_logger.MetricsLoggerOptions(log_dir=mt_config.tensorboard_dir)

# Profiler configurations
profiler_options = None
if mt_config.profiler:
set_profile_options = True
platform_version = jax.extend.backend.get_backend().platform_version.strip()
if platform_version.startswith("Pathways"):
max_logging.log("Pathways backend detected. Disabling setting profile options.")
set_profile_options = False
profiler_options = profiler.ProfilerOptions(
log_dir=mt_config.tensorboard_dir,
skip_first_n_steps=mt_config.skip_first_n_steps_for_profiler,
profiler_steps=mt_config.profiler_steps,
set_profile_options=set_profile_options,
)

return DPOTrainingConfig(
eval_every_n_steps=mt_config.eval_interval,
max_steps=mt_config.steps,
gradient_accumulation_steps=mt_config.gradient_accumulation_steps,
checkpoint_root_directory=mt_config.checkpoint_dir,
checkpointing_options=checkpointing_options,
metrics_logging_options=metrics_logging_options,
profiler_options=profiler_options,
      algorithm="dpo",  # TODO: add support for "orpo"
beta=mt_config.dpo_beta,
label_smoothing=mt_config.dpo_label_smoothing,
max_prompt_length=mt_config.max_target_length // 2,
max_response_length=mt_config.max_target_length // 2,
)


def setup_trainer_state(mt_config, goodput_recorder=None):
"""Set up prerequisites for training loop."""
tunix_config = get_tunix_config(mt_config)

with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT):
model, mesh = model_creation_utils.create_nnx_model(mt_config)

# Wrap model with Tunix adapter for consistent interface
model = TunixMaxTextAdapter(model)

learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config)
# pass in model for muon
optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model)

if mt_config.gradient_clipping_threshold > 0:
optimizer = optax.chain(
optax.clip_by_global_norm(max_norm=mt_config.gradient_clipping_threshold),
optimizer,
)

# Pre-shard the optimizer to avoid TypeError in Tunix _shard_optimizer
# Tunix will now detect it's already sharded and skip its internal sharding logic.
with mesh, nn.logical_axis_rules(mt_config.logical_axis_rules):
nnx_optimizer = nnx.Optimizer(model, optimizer, wrt=nnx.Param)
opt_state = nnx.state(nnx_optimizer, nnx.optimizer.OptState)
opt_pspecs = nnx.get_partition_spec(opt_state)
opt_sharded_state = jax.lax.with_sharding_constraint(opt_state, opt_pspecs)
nnx.update(nnx_optimizer, opt_sharded_state)

with maybe_record_goodput(goodput_recorder, GoodputEvent.TRAINING_PREPARATION):
training_hooks = hooks.SFTTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder)
data_hooks = hooks.SFTDataHooks(mt_config, mesh, goodput_recorder)

tokenizer = tunix.Tokenizer(
tokenizer_type=mt_config.tokenizer_type,
tokenizer_path=mt_config.tokenizer_path,
add_bos=mt_config.add_bos,
add_eos=mt_config.add_eos,
hf_access_token=mt_config.hf_access_token,
)

  # Pass the pre-sharded nnx.Optimizer directly to DPOTrainer. The MaxText
  # pipeline pre-tokenizes all inputs, so the trainer receives tokenizer=None
  # rather than the Tokenizer constructed above.
with mesh, nn.logical_axis_rules(mt_config.logical_axis_rules):
trainer = DPOTrainer(
model=model, ref_model=None, optimizer=nnx_optimizer, training_config=tunix_config, tokenizer=None
)
trainer.with_training_hooks(training_hooks)
trainer.with_data_hooks(data_hooks)

return trainer, mesh


def train_model(mt_config: pyconfig.HyperParameters, trainer, mesh):
"""Runs the DPO training loop in Tunix."""
with jax.set_mesh(mesh), mesh, nn.logical_axis_rules(mt_config.logical_axis_rules):
trainer.train(trainer.data_hooks.train_data_iterator, trainer.data_hooks.eval_data_iterator)
return trainer


def train(mt_config, goodput_recorder=None):
"""Main method for DPO training.

Args:
mt_config: MaxText config.
goodput_recorder: An optional GoodputRecorder to record performance metrics.
"""
trainer, mesh = setup_trainer_state(mt_config, goodput_recorder)
_job_completed_gracefully = False
try:
trainer = train_model(mt_config, trainer, mesh)
_job_completed_gracefully = True
finally:
if _job_completed_gracefully:
record_goodput(goodput_recorder, RECORD_JOB_END_TIME)
return trainer, mesh


def main(argv: list[str]) -> None:
"""Main function to run DPO training.

Args:
argv: Command-line arguments.
"""

pathwaysutils.initialize()

mt_config = pyconfig.initialize(argv)
max_utils.print_system_information()

goodput_recorder = create_goodput_recorder(mt_config)
record_goodput(goodput_recorder, RECORD_JOB_START_TIME)
with maybe_monitor_goodput(mt_config):
train(mt_config, goodput_recorder)


if __name__ == "__main__":
app.run(main)