Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ load_balance_loss_weight: 0.0 # weight for the load balance loss
use_random_routing: false # whether to use random routing for debug/test purpose
use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism
use_gather_mosaic_kernel: false # whether to use a custom mosaic kernel for token gather ops
# tunable tiling dimensions used for mlp gmm
# megablox/jax ragged dot - supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`)
# tokamax ragged dot - supports all 18 configs
Expand Down
17 changes: 9 additions & 8 deletions src/maxtext/configs/models/deepseek3-671b-batchsplit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for DeepSeek V3 - 671B that uses fsdp on two logical axes
# model config for DeepSeek V3 - 671B that uses batch split schedule

# For DeepSeek default device-limited routing,
# please set n_routing_groups=8 and topk_routing_group=4 in your command-line arguments.
Expand Down Expand Up @@ -55,17 +55,18 @@ rope_interleave: True
rope_truncate: True
rope_attention_scaling: False

use_batch_split_schedule: True
shard_mode: "explicit"
override_logical_axis_rules: True
mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']]
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
['activation_batch_moe', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
['activation_kv_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert', 'context']],
['activation_embed_and_logits_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_norm_length', ['context']],
['activation_norm_length_moe', ['context']],
['activation_norm_length', []],
['activation_norm_length_moe', []],
['activation_heads', []],
['activation_stage', 'stage'],
['embed', ['fsdp']],
Expand All @@ -81,8 +82,8 @@ logical_axis_rules: [
['kv_heads', ['fsdp_transpose']],
['heads', ['fsdp_transpose']],
['mlp', ['fsdp_transpose']],
['mlp_only_fsdp_transpose', ['fsdp_transpose']],
['expert_only', ['expert']],
['fsdp_transpose_only', ['fsdp_transpose']],
['fsdp_transpose_and_expert', ['fsdp_transpose', 'expert']],
['fsdp_transpose_only', ['fsdp_transpose']],
['expert_only', ['expert']],
['diloco', 'diloco'],
]
6 changes: 4 additions & 2 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,10 @@ class MoEGeneral(BaseModel):
False,
description="Whether to use Ring of Experts for sparse matmul expert parallelism.",
)
use_gather_mosaic_kernel: bool = Field(
False,
description="Whether to use a custom mosaic kernel for token gather ops.",
)
use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.")
interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.")
expert_shard_attention_option: Literal["fsdp", "context"] = Field(
Expand Down Expand Up @@ -2597,8 +2601,6 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
self.use_grpo = False

if self.use_batch_split_schedule:
if not (self.decoder_block == DecoderBlockType.DEEPSEEK and self.sparse_matmul and self.use_tokamax_gmm):
raise ValueError("Batch split only supports deepseek, with `sparse_matmul=True` and `use_tokamax_gmm=True`")
if self.quantization and not (self.use_qwix_quantization and self.quantization == "fp8_full"):
raise ValueError(
"Batch split quantization only supports `use_qwix_quantization=True` and `quantization=fp8_full`"
Expand Down
17 changes: 17 additions & 0 deletions src/maxtext/kernels/attention/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Attention kernels."""

from maxtext.kernels.attention import splash_attention_kernel
144 changes: 136 additions & 8 deletions src/maxtext/kernels/attention/splash_attention_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,7 @@ def _wrapped(
def reshape_activations(activations):
if activations.ndim == 4: # pytype: disable=attribute-error
kv_heads, q_heads_per_kv_head, q_seq_len, head_dim = activations.shape # pytype: disable=attribute-error
return activations.reshape(
kv_heads * q_heads_per_kv_head, q_seq_len, head_dim
) # pytype: disable=attribute-error
return activations.reshape(kv_heads * q_heads_per_kv_head, q_seq_len, head_dim) # pytype: disable=attribute-error
return activations

def reshape_residuals(residuals):
Expand Down Expand Up @@ -1166,10 +1164,7 @@ def _splash_attention_fwd(
mask_function: MaskFunctionType | None,
attn_logits_soft_cap: float | None = None,
interpret: bool = False,
) -> tuple[
tuple[jax.Array],
SplashResidualsType,
]:
) -> tuple[tuple[jax.Array], SplashResidualsType,]:
"""Forward pass for splash attention."""
if save_residuals:
raise NotImplementedError("Higher-order AD not supported")
Expand Down Expand Up @@ -1606,7 +1601,6 @@ def init():
)

def body(i, _):

slice_k = pl.ds(i * bkv_compute, bkv_compute)
q = q_ref[...] # We keep q potentially transposed, since it's always RHS

Expand Down Expand Up @@ -2238,6 +2232,120 @@ def _collapse_partial_mask_blocks(mask_info: mask_info_lib.MaskInfo | None):
)


@partial(
    jax.jit,
    static_argnames=[
        "is_mqa",
        "block_sizes",
        "save_residuals",
        "mask_value",
        "attn_logits_soft_cap",
        "residual_checkpoint_name",
        "mask_function",
        "interpret",
    ],
)
def _splash_attention_manual_fwd(
    fwd_mask_info: mask_info_lib.MaskInfo,
    dq_mask_info: mask_info_lib.MaskInfo | None,
    dkv_mask_info: mask_info_lib.MaskInfo | None,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    segment_ids: SegmentIds | None = None,
    sinks: jax.Array | None = None,
    *,
    is_mqa: bool,
    block_sizes: BlockSizes | None,
    save_residuals: bool,
    mask_value: float,
    attn_logits_soft_cap: float | None,
    residual_checkpoint_name: str | None,
    mask_function: MaskFunctionType | None,
    interpret: bool,
) -> SplashCustomReturnType:
  """Runs the splash-attention forward pass outside the custom-VJP machinery.

  Intended for manual-differentiation schedules where the caller pairs this
  with `_splash_attention_manual_bwd`, so residuals (`logsumexp`) are always
  materialized and returned alongside the output.

  Args:
    fwd_mask_info: Mask metadata for the forward kernel; partial mask blocks
      are collapsed to 3D before being handed to the kernel.
    dq_mask_info: Backward-pass (dq) mask metadata; unused here, accepted only
      so the signature mirrors the backward entry point.
    dkv_mask_info: Backward-pass (dkv) mask metadata; likewise unused here.
    q, k, v: Attention query/key/value arrays.
    segment_ids: Optional segment ids for block-sparse masking.
    sinks: Accepted for signature parity but not used by this path.
      NOTE(review): presumably attention-sink logits — confirm against callers.
    is_mqa / block_sizes / save_residuals / mask_value / attn_logits_soft_cap /
      residual_checkpoint_name / mask_function / interpret: Static kernel
      configuration forwarded verbatim to `_splash_attention_forward`.

  Returns:
    `(out, logsumexp)` — attention output and the log-sum-exp residuals needed
    by the manual backward pass.

  Raises:
    ValueError: If `save_residuals` is False; this entry point exists to
      produce residuals, so calling it without them is a caller bug.
  """

  def _collapse_partial_mask_blocks(mask_info: mask_info_lib.MaskInfo | None):
    # The kernel expects partial mask blocks flattened to (num_blocks, bq, bkv);
    # leave the mask info untouched when there is nothing to collapse.
    if mask_info is None or mask_info.partial_mask_blocks is None:
      return mask_info
    return mask_info._replace(
        partial_mask_blocks=mask_info.partial_mask_blocks.reshape(-1, *mask_info.partial_mask_blocks.shape[-2:])
    )

  if not save_residuals:
    raise ValueError("Expected save_residuals to be `True`.")

  # The dq/dkv mask infos are only consumed by the backward pass; drop them
  # immediately instead of collapsing them first (the original collapsed both
  # and then deleted the results, which was dead work).
  del dq_mask_info, dkv_mask_info

  fwd_mask_info = _collapse_partial_mask_blocks(fwd_mask_info)

  out, (logsumexp,) = _splash_attention_forward(  # pytype: disable=wrong-arg-types
      fwd_mask_info,
      q,
      k,
      v,
      segment_ids,
      mask_value=mask_value,
      is_mqa=is_mqa,
      block_sizes=block_sizes,
      residual_checkpoint_name=residual_checkpoint_name,
      save_residuals=True,
      mask_function=mask_function,
      attn_logits_soft_cap=attn_logits_soft_cap,
      interpret=interpret,
  )
  return out, logsumexp


def _splash_attention_manual_bwd(
    fwd_mask_info: mask_info_lib.MaskInfo,
    dq_mask_info: mask_info_lib.MaskInfo | None,
    dkv_mask_info: mask_info_lib.MaskInfo | None,
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    out: jax.Array,
    logsumexp: jax.Array,
    do: jax.Array,
    segment_ids: SegmentIds | None = None,
    sinks: jax.Array | None = None,
    *,
    is_mqa: bool,
    block_sizes: BlockSizes | None,
    save_residuals: bool,
    mask_value: float,
    attn_logits_soft_cap: float | None,
    residual_checkpoint_name: str | None,
    mask_function: MaskFunctionType | None,
    interpret: bool,
):
  """Runs the splash-attention backward pass outside the custom-VJP machinery.

  Companion to `_splash_attention_manual_fwd`: repackages the forward
  residuals into the tuple layout `_splash_attention_bwd` expects and returns
  only the input gradients.

  Args:
    fwd_mask_info: Forward mask metadata; not needed for the backward pass.
    dq_mask_info / dkv_mask_info: Mask metadata for the dq and dkv kernels.
    q, k, v: The forward-pass inputs.
    out: The forward-pass attention output.
    logsumexp: Log-sum-exp residuals produced by the forward pass.
    do: Cotangent of the attention output.
    segment_ids: Optional segment ids, matching the forward call.
    sinks: Accepted for signature parity with the forward entry point; unused.
    Remaining keyword-only args: static kernel configuration forwarded
      verbatim to `_splash_attention_bwd`.

  Returns:
    `(dq, dk, dv)` gradients with respect to `q`, `k`, and `v`.
  """
  # Only the backward mask infos matter here.
  del fwd_mask_info

  residuals = (
      q,
      k,
      v,
      segment_ids,
      out,
      logsumexp,
      dq_mask_info,
      dkv_mask_info,
  )
  grads = _splash_attention_bwd(
      save_residuals=save_residuals,
      mask_value=mask_value,
      is_mqa=is_mqa,
      block_sizes=block_sizes,
      residual_checkpoint_name=residual_checkpoint_name,
      mask_function=mask_function,
      attn_logits_soft_cap=attn_logits_soft_cap,
      interpret=interpret,
      res=residuals,
      do=do,
  )
  # _splash_attention_bwd also returns placeholders for the non-differentiable
  # inputs; keep only the q/k/v gradients.
  _, _, _, dq, dk, dv, _ = grads
  return dq, dk, dv


@jax.tree_util.register_pytree_node_class
class SplashAttentionKernel:
"""Defines a SplashAttention kernel object."""
Expand All @@ -2264,6 +2372,26 @@ def __call__(self, *args, **kwargs) -> SplashCustomReturnType:
**self.kwargs,
)

def manual_fwd(self, *args, **kwargs) -> SplashCustomReturnType:
  """Runs the forward pass via the kernel's stored mask infos.

  Thin delegation to `_splash_attention_manual_fwd`: prepends the mask
  metadata captured at construction time and merges the kernel's stored
  keyword configuration after the caller's.
  """
  mask_infos = (self.fwd_mask_info, self.dq_mask_info, self.dkv_mask_info)
  return _splash_attention_manual_fwd(*mask_infos, *args, **kwargs, **self.kwargs)

def manual_bwd(self, *args, **kwargs):
  """Runs the backward pass via the kernel's stored mask infos.

  Thin delegation to `_splash_attention_manual_bwd`: prepends the mask
  metadata captured at construction time and merges the kernel's stored
  keyword configuration after the caller's.
  """
  mask_infos = (self.fwd_mask_info, self.dq_mask_info, self.dkv_mask_info)
  return _splash_attention_manual_bwd(*mask_infos, *args, **kwargs, **self.kwargs)

def manual_sharding_spec(self, sharding: jax.sharding.NamedSharding):
"""Returns a value that can be used as a shard_map partition spec for the kernel."""
if self.fwd_mask_info.data_next is not None:
Expand Down
7 changes: 1 addition & 6 deletions src/maxtext/kernels/sort_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ def _route_impl(
assert (
tokens.shape[0] == selected_experts.shape[0] and selected_experts.ndim == 2
), f"{tokens.shape=}, {selected_experts.shape=}"
if use_custom_mosaic_kernel:
raise NotImplementedError("Custom Mosaic kernel not implemented.")
inds = jnp.argsort(jnp.ravel(selected_experts)) // selected_experts.shape[1]
return _sort_impl(tokens, inds, use_custom_mosaic_kernel)

Expand All @@ -114,7 +112,4 @@ def _unroute_impl(


def _sort_impl(tokens: jax.Array, inds: jax.Array, use_custom_mosaic_kernel: bool) -> jax.Array:
if use_custom_mosaic_kernel:
raise NotImplementedError("Custom Mosaic kernel not implemented.")
else:
return tokens[inds, ...]
return tokens[inds, ...]
5 changes: 1 addition & 4 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,12 +919,9 @@ def __call__(
y,
self.variables["params"]["moe_layers"],
decoder_positions,
decoder_segment_ids,
model_mode=model_mode,
mesh=mesh,
quant=self.quant,
cfg=cfg,
policy=policy,
num_layers=num_moe_layers,
)
else:
y, _ = self.scan_decoder_layers(
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/layers/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array:
"activation_embed",
)
)
out_pspec = logical_to_mesh_axes(output_axis_names, self.mesh)
out_pspec = logical_to_mesh_axes(output_axis_names, self.mesh, rules=self.config.logical_axis_rules)

out_sharding = NamedSharding(self.mesh, out_pspec) if self.config.shard_mode == ShardMode.EXPLICIT else None

Expand Down
7 changes: 1 addition & 6 deletions src/maxtext/layers/nnx_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,20 +976,15 @@ def __call__(
num_moe = cfg.num_decoder_layers - cfg.first_num_dense_layers

if cfg.use_batch_split_schedule:
policy = self.get_remat_policy()

mock_params = self._build_linen_params(self.moe_layer)

y = deepseek_batchsplit.scan_batch_split_layers(
y,
mock_params,
decoder_positions,
decoder_segment_ids,
model_mode=model_mode,
mesh=self.mesh,
quant=self.quant,
cfg=cfg,
policy=policy,
num_layers=num_moe,
)
else:
y, self.moe_layer = self._apply_layers_sequentially(
Expand Down
Loading
Loading