From 33164c6771005f2d7662659f9eeed7c00d129abb Mon Sep 17 00:00:00 2001 From: csgoogle Date: Mon, 6 Apr 2026 15:39:16 +0530 Subject: [PATCH 1/4] Change bias initialization from 'embed' to 'heads' Fix the bias sharding axis, it should be output axis instead of input one. --- src/maxdiffusion/models/attention_flax.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 453c1650..8b17dbbb 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -979,7 +979,7 @@ def __init__( precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, - ("embed",), + ("heads",), ), ) @@ -993,7 +993,7 @@ def __init__( precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, - ("embed",), + ("heads",), ), ) @@ -1007,7 +1007,7 @@ def __init__( precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, - ("embed",), + ("heads",), ), ) @@ -1021,7 +1021,7 @@ def __init__( precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, - ("heads",), + ("embed",), ), ) From bbdf2fe514262e38a3a193c9cedbebd19e75caec Mon Sep 17 00:00:00 2001 From: csgoogle Date: Mon, 6 Apr 2026 16:13:00 +0530 Subject: [PATCH 2/4] Update kernel and bias initialization axes in attention layer Fix in FlaxAttention --- src/maxdiffusion/models/attention_flax.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 8b17dbbb..9530dd09 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -1332,11 +1332,13 @@ def setup(self): precision=self.precision, ) + proj_attn_kernel_axes = ("heads", "embed") + self.proj_attn = nn.Dense( self.query_dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes), 
+ kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), proj_attn_kernel_axes), use_bias=True, - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)), dtype=self.dtype, param_dtype=self.weights_dtype, name="i_proj", @@ -1345,9 +1347,9 @@ def setup(self): self.encoder_proj_attn = nn.Dense( self.query_dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes), + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), proj_attn_kernel_axes), use_bias=True, - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)), dtype=self.dtype, param_dtype=self.weights_dtype, name="e_proj", From d1d43a4ef7ab123d4fc02766b5974db455b40579 Mon Sep 17 00:00:00 2001 From: csgoogle Date: Mon, 6 Apr 2026 16:15:17 +0530 Subject: [PATCH 3/4] Reorder kernel initialization parameters Fix for ApproximateGelu and WanFeedForward too --- src/maxdiffusion/models/wan/transformers/transformer_wan.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 11d7cad2..e217028f 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -193,8 +193,8 @@ def __init__( kernel_init=nnx.with_partitioning( nnx.initializers.xavier_uniform(), ( - "mlp", "embed", + "mlp", ), ), bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), @@ -248,9 +248,9 @@ def __init__( precision=precision, kernel_init=nnx.with_partitioning( nnx.initializers.xavier_uniform(), - ( + ( + "mlp", "embed", - "mlp", ), ), ) From 540375b1745e01637796bf3493e8a467f5a78e6b Mon Sep 17 00:00:00 2001 From: csgoogle Date: Mon, 6 Apr 2026 16:21:48 +0530 
Subject: [PATCH 4/4] Change bias_init partitioning from 'embed' to 'mlp' Fix the GELU bias too. --- src/maxdiffusion/models/wan/transformers/transformer_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index e217028f..2a614882 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -197,7 +197,7 @@ def __init__( "mlp", ), ), - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), ) def __call__(self, x: jax.Array) -> jax.Array: