Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,11 @@ num_experts_per_tok: 1
megablox: true
sparse_matmul: true
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
ragged_buffer_factor: -1.0 # a factor to determine the size of the ragged buffer for routed MoE activations.
# By default (-1), this buffer will be worst case size to ensure no dropping.
# When set to 1.0 this buffer is set to the size assuming perfectly balanced routing. If the routing
# dictates a size larger than this then tokens will be dropped.
# In general if ragged_buffer_factor>0, the ragged_buffer_size is balanced_size * ragged_buffer_factor.
load_balance_loss_weight: 0.0 # weight for the load balance loss
use_random_routing: false # whether to use random routing for debug/test purpose
use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
Expand Down
1 change: 1 addition & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,7 @@ class MoEGeneral(BaseModel):
num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.")
num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
ragged_buffer_factor: float = Field(-1.0, description="Ragged buffer factor. If < 0, ragged buffer is worst case size.")
load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
use_custom_sort_vjp: bool = Field(
True,
Expand Down
11 changes: 9 additions & 2 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,14 @@ def transform_array(input_array, shard_id, strategy, is_batch_sharded):
)
return input_offsets, send_sizes, output_offsets, recv_sizes

def get_ragged_buffer_size(self, local_expert_size, local_batch):
  """Computes the row count of the ragged buffer for routed MoE activations.

  Args:
    local_expert_size: number of experts hosted on this shard.
    local_batch: number of local input rows (tokens) on this shard.

  Returns:
    int: buffer size in rows. When ``config.ragged_buffer_factor`` is
    positive, the perfectly-balanced size (``local_batch``) is scaled by
    that factor, which may drop tokens if routing is imbalanced. Otherwise
    the worst-case size is returned: every local token routed to experts on
    this shard, capped at ``num_experts_per_tok`` assignments per token.
  """
  factor = self.config.ragged_buffer_factor
  if factor <= 0.0:
    # Worst case: each token can contribute at most one assignment per local
    # expert, and no more than num_experts_per_tok assignments overall.
    per_token_cap = min(local_expert_size, self.config.num_experts_per_tok)
    return int(local_batch * per_token_cap)
  # User-specified sizing: scale the balanced size by the configured factor.
  return int(local_batch * factor)

def transform_bias(self, experts_index, *biases):
  """Gathers entries at ``experts_index`` from each provided bias tensor.

  Args:
    experts_index: index (or index array) of the chosen experts.
    *biases: any number of bias tensors to gather from.

  Returns:
    A tuple with one gathered value per input bias, in the same order.
  """
  selected = []
  for bias in biases:
    selected.append(bias[experts_index])
  return tuple(selected)
Expand Down Expand Up @@ -1180,8 +1188,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r
# In the worst case, all of the global input data is assigned to each expert in the current shard.
# This would result in num_expert_shards * input_size * experts_per_shard assignments. However, if
# experts_per_shard > num_experts_per_tok we cannot assign more than num_experts_per_tok to all of the inputs.
max_local_experts_per_tok = min(local_expert_size, self.config.num_experts_per_tok)
buffer_size = int(num_expert_parallelism * batch_size * sequence_length * max_local_experts_per_tok)
buffer_size = self.get_ragged_buffer_size(local_expert_size, jnp.shape(x)[0])
output_shape = jax.lax.empty((buffer_size, self.config.emb_dim), dtype=x.dtype)

x = jax.lax.ragged_all_to_all(
Expand Down
Loading