diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 453c1650..9530dd09 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -979,7 +979,7 @@ def __init__( precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, - ("embed",), + ("heads",), ), ) @@ -993,7 +993,7 @@ def __init__( precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, - ("embed",), + ("heads",), ), ) @@ -1007,7 +1007,7 @@ def __init__( precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, - ("embed",), + ("heads",), ), ) @@ -1021,7 +1021,7 @@ def __init__( precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, - ("heads",), + ("embed",), ), ) @@ -1332,11 +1332,13 @@ def setup(self): precision=self.precision, ) + proj_attn_kernel_axes = ("heads", "embed") + self.proj_attn = nn.Dense( self.query_dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes), + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), proj_attn_kernel_axes), use_bias=True, - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)), dtype=self.dtype, param_dtype=self.weights_dtype, name="i_proj", @@ -1345,9 +1347,9 @@ def setup(self): self.encoder_proj_attn = nn.Dense( self.query_dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes), + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), proj_attn_kernel_axes), use_bias=True, - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)), dtype=self.dtype, param_dtype=self.weights_dtype, name="e_proj", diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 11d7cad2..2a614882 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -193,11 +193,11 @@ def __init__( kernel_init=nnx.with_partitioning( nnx.initializers.xavier_uniform(), ( - "mlp", "embed", + "mlp", ), ), - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), ) def __call__(self, x: jax.Array) -> jax.Array: @@ -248,9 +248,9 @@ def __init__( precision=precision, kernel_init=nnx.with_partitioning( nnx.initializers.xavier_uniform(), - ( + ( + "mlp", "embed", - "mlp", ), ), )