diff --git a/.gitignore b/.gitignore index bd4a64b8..1746dd2c 100644 --- a/.gitignore +++ b/.gitignore @@ -168,6 +168,7 @@ tags .DS_Store # RL pipelines may produce mp4 outputs *.mp4 +!assets/wan_animate/**/*.mp4 # dependencies /transformers @@ -177,6 +178,20 @@ tags wandb -# Gemini CLI +# Local assistant tooling +.codex +.codex/ +.claude +.claude/ +.gemini .gemini/ +Gemini.md +gemini.md + gha-creds-*.json + +#jax cache +.jax_cache/ +.mplconfig/ + +.tpu_logs/ diff --git a/assets/wan_animate/src_face.mp4 b/assets/wan_animate/src_face.mp4 new file mode 100644 index 00000000..e902be96 Binary files /dev/null and b/assets/wan_animate/src_face.mp4 differ diff --git a/assets/wan_animate/src_pose.mp4 b/assets/wan_animate/src_pose.mp4 new file mode 100644 index 00000000..976e5071 Binary files /dev/null and b/assets/wan_animate/src_pose.mp4 differ diff --git a/assets/wan_animate/src_ref.png b/assets/wan_animate/src_ref.png new file mode 100644 index 00000000..9167a9e9 Binary files /dev/null and b/assets/wan_animate/src_ref.png differ diff --git a/src/maxdiffusion/configs/base_wan_animate_27b.yml b/src/maxdiffusion/configs/base_wan_animate_27b.yml new file mode 100644 index 00000000..32e3ea17 --- /dev/null +++ b/src/maxdiffusion/configs/base_wan_animate_27b.yml @@ -0,0 +1,382 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This sentinel is a reminder to choose a real run name. 
+run_name: '' + +metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written. +# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/ +write_metrics: True + +timing_metrics_file: "" # for testing, local file that stores function timing metrics such as state creation, compilation. If empty, no metrics are written. +write_timing_metrics: True + +gcs_metrics: False +# If true save config to GCS in {base_output_directory}/{run_name}/ +save_config_to_gcs: False +log_period: 100 + +pretrained_model_name_or_path: 'Wan-AI/Wan2.2-Animate-14B-Diffusers' +model_name: wan2.2 +model_type: 'I2V' + +# Overrides the transformer from pretrained_model_name_or_path +wan_transformer_pretrained_model_name_or_path: '' + +unet_checkpoint: '' +revision: '' +# This will convert the weights to this dtype. +# When running inference on TPUv5e, use weights_dtype: 'bfloat16' +weights_dtype: 'bfloat16' +# This sets the layer's dtype in the model. Ex: nn.Dense(dtype=activations_dtype) +activations_dtype: 'bfloat16' + +# Replicates vae across devices instead of using the model's sharding annotations for sharding. +replicate_vae: False +# Number of devices to shard VAE spatial activations across. -1 uses all devices. +vae_spatial: -1 + +# matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision +# Options are "DEFAULT", "HIGH", "HIGHEST" +# fp32 activations and fp32 weights with HIGHEST will provide the best precision +# at the cost of time. +precision: "DEFAULT" +# Use jax.lax.scan for transformer layers +scan_layers: False + +# if False state is not jitted and instead replicate is called. This is good for debugging on single host +# It must be True for multi-host. 
+jit_initializers: True + +# Set true to load weights from pytorch +from_pt: True +split_head_dim: True +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring +flash_min_seq_length: 4096 +# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. +# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. +# However, when padding tokens are significant, this will lead to worse quality and should be set to True. +mask_padding_tokens: True +# Maxdiffusion has 2 types of attention sharding strategies: +# 1. attention_sharding_uniform = True : same sequence sharding rules applied for q in both (self and cross attention) +# 2. attention_sharding_uniform = False : Heads are sharded uniformly across devices for self attention while sequence is sharded +# in cross attention q. +attention_sharding_uniform: True +dropout: 0.1 + +flash_block_sizes: { + "block_q" : 512, + "block_kv_compute" : 512, + "block_kv" : 512, + "block_q_dkv" : 512, + "block_kv_dkv" : 512, + "block_kv_dkv_compute" : 512, + "block_q_dq" : 512, + "block_kv_dq" : 512, + "use_fused_bwd_kernel": False, +} +# Use on v6e +# flash_block_sizes: { +# "block_q" : 3024, +# "block_kv_compute" : 1024, +# "block_kv" : 2048, +# "block_q_dkv" : 3024, +# "block_kv_dkv" : 2048, +# "block_kv_dkv_compute" : 2048, +# "block_q_dq" : 3024, +# "block_kv_dq" : 2048, +# "use_fused_bwd_kernel": False, +# } +# GroupNorm groups +norm_num_groups: 32 + +# train text_encoder - Currently not supported for SDXL +train_text_encoder: False +text_encoder_learning_rate: 4.25e-6 + +# https://arxiv.org/pdf/2305.08891.pdf +snr_gamma: -1.0 + +timestep_bias: { + # a value of later will increase the frequency of the model's final training steps. + # none, earlier, later, range + strategy: "none", + # multiplier for bias, a value of 2.0 will double the weight of the bias, 0.5 will halve it. 
+ multiplier: 1.0, + # when using strategy=range, the beginning (inclusive) timestep to bias. + begin: 0, + # when using strategy=range, the final step (inclusive) to bias. + end: 1000, + # portion of timesteps to bias. + # 0.5 will bias one half of the timesteps. Value of strategy determines + # whether the biased portions are in the earlier or later timesteps. + portion: 0.25 +} + +# Override parameters from checkpoints's scheduler. +diffusion_scheduler_config: { + _class_name: 'FlaxEulerDiscreteScheduler', + prediction_type: 'epsilon', + rescale_zero_terminal_snr: False, + timestep_spacing: 'trailing' +} + +# Output directory +# Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" +base_output_directory: "" + +# Hardware +hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +skip_jax_distributed_system: False + +# Parallelism +mesh_axes: ['data', 'fsdp', 'context', 'tensor'] + +# batch : batch dimension of data and activations +# hidden : +# embed : attention qkv dense layer hidden dim named as embed +# heads : attention head dim = num_heads * head_dim +# length : attention sequence length +# temb_in : dense.shape[0] of resnet dense before conv +# out_c : dense.shape[1] of resnet dense before conv +# out_channels : conv.shape[-1] activation +# keep_1 : conv.shape[0] weight +# keep_2 : conv.shape[1] weight +# conv_in : conv.shape[2] weight +# conv_out : conv.shape[-1] weight +logical_axis_rules: [ + ['batch', ['data', 'fsdp']], + ['activation_batch', ['data', 'fsdp']], + ['activation_self_attn_heads', ['context', 'tensor']], + ['activation_cross_attn_q_length', ['context', 'tensor']], + ['activation_length', 'context'], + ['activation_heads', 'tensor'], + ['mlp','tensor'], + ['embed', ['context', 'fsdp']], + ['heads', 'tensor'], + ['norm', 'tensor'], + ['conv_batch', ['data', 'context', 'fsdp']], + ['out_channels', 'tensor'], + ['conv_out', 'context'], + ] +data_sharding: [['data', 'fsdp', 'context', 'tensor']] + +# One 
axis for each parallelism type may hold a placeholder (-1) +# value to auto-shard based on available slices and devices. +# By default, product of the DCN axes should equal number of slices +# and product of the ICI axes should equal number of devices per slice. +dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded +dcn_fsdp_parallelism: 1 +dcn_context_parallelism: -1 +dcn_tensor_parallelism: 1 +ici_data_parallelism: 1 +ici_fsdp_parallelism: 1 +ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded +ici_tensor_parallelism: 1 + +allow_split_physical_axes: False + +# Dataset +# Replace with dataset path or train_data_dir. One has to be set. +dataset_name: 'diffusers/pokemon-gpt4-captions' +train_split: 'train' +dataset_type: 'tfrecord' +cache_latents_text_encoder_outputs: True +# cache_latents_text_encoder_outputs only apply to dataset_type="tf", +# only apply to small dataset that fits in memory +# prepare image latents and text encoder outputs +# Reduce memory consumption and reduce step time during training +# transformed dataset is saved at dataset_save_location +dataset_save_location: '' +load_tfrecord_cached: True +train_data_dir: '' +dataset_config_name: '' +jax_cache_dir: '.jax_cache' +hf_data_dir: '' +hf_train_files: '' +hf_access_token: '' +image_column: 'image' +caption_column: 'text' +resolution: 1024 +center_crop: False +random_flip: False +# If cache_latents_text_encoder_outputs is True +# the num_proc is set to 1 +tokenize_captions_num_proc: 4 +transform_images_num_proc: 4 +reuse_example_batch: False +enable_data_shuffling: True + +# Defines the type of gradient checkpoint to enable. 
+# NONE - means no gradient checkpoint +# FULL - means full gradient checkpoint, whenever possible (minimum memory usage) +# MATMUL_WITHOUT_BATCH - means gradient checkpoint for every linear/matmul operation, +# except for ones that involve batch dimension - that means that all attention and projection +# layers will have gradient checkpoint, but not the backward with respect to the parameters. +# OFFLOAD_MATMUL_WITHOUT_BATCH - same as MATMUL_WITHOUT_BATCH but offload instead of recomputing. +# CUSTOM - set names to offload and save. +remat_policy: "NONE" +# For CUSTOM policy set below, current annotations are for: attn_output, query_proj, key_proj, value_proj +# xq_out, xk_out, ffn_activation +names_which_can_be_saved: [] +names_which_can_be_offloaded: [] + +# checkpoint every number of samples, -1 means don't checkpoint. +checkpoint_every: -1 +checkpoint_dir: "" +# enables one replica to read the ckpt then broadcast to the rest +enable_single_replica_ckpt_restoring: False + +# Training loop +learning_rate: 1.e-5 +scale_lr: False +max_train_samples: -1 +# max_train_steps takes priority over num_train_epochs. +max_train_steps: 1500 +num_train_epochs: 1 +seed: 0 +output_dir: 'sdxl-model-finetuned' +per_device_batch_size: 0.125 +# If global_batch_size % jax.device_count is not 0, use FSDP sharding. +global_batch_size: 0 + +# For creating tfrecords from dataset +tfrecords_dir: '' +no_records_per_shard: 0 +enable_eval_timesteps: False +timesteps_list: [125, 250, 375, 500, 625, 750, 875] +num_eval_samples: 420 + +warmup_steps_fraction: 0.1 +learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps. +save_optimizer: False + +# However you may choose a longer schedule (learning_rate_schedule_steps > steps), in which case the training will end before +# dropping fully down. Or you may choose a shorter schedule, where the unspecified steps will have a learning rate of 0. 
+ +# AdamW optimizer parameters +adam_b1: 0.9 # Exponential decay rate to track the first moment of past gradients. +adam_b2: 0.999 # Exponential decay rate to track the second moment of past gradients. +adam_eps: 1.e-8 # A small constant applied to denominator outside of the square root. +adam_weight_decay: 0 # AdamW Weight decay +max_grad_norm: 1.0 + +enable_profiler: False +# Skip first n steps for profiling, to omit things like compilation and to give +# the iteration time a chance to stabilize. +skip_first_n_steps_for_profiler: 5 +profiler_steps: 10 + +# Enable JAX named scopes for detailed profiling and debugging +# When enabled, adds named scopes around key operations in transformer and attention layers +enable_jax_named_scopes: False + +# Generation parameters +prompt: "The person from the reference image follows the motion from the driving videos with natural body movement, stable identity, expressive face, cinematic framing, and realistic lighting." +prompt_2: "A clean, high-quality character animation of the reference subject matching the pose and facial performance from the conditioning videos, with smooth motion and detailed textures." +negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" +do_classifier_free_guidance: True +height: 720 +width: 1280 +num_frames: 121 +flow_shift: 5.0 + +# Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py +# guidance scale factor for low noise transformer +guidance_scale_low: 3.0 + +# guidance scale factor for high noise transformer +guidance_scale_high: 4.0 + +# The timestep threshold. 
If `t` is at or above this value, +# the `high_noise_model` is considered as the required model. +# timestep to switch between low noise and high noise transformer +boundary_ratio: 0.875 + +# Diffusion CFG cache (FasterCache-style) +use_cfg_cache: False +# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass +# when predicted output change (based on accumulated latent/timestep drift) is small +use_sen_cache: False + +# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf +guidance_rescale: 0.0 +num_inference_steps: 30 +fps: 24 +save_final_checkpoint: False + +# Wan Animate local-input overrides (used by generate_wan_animate.py) +# These sample assets live under assets/wan_animate/. +mode: "animate" +reference_image_path: "assets/wan_animate/src_ref.png" +pose_video_path: "assets/wan_animate/src_pose.mp4" +face_video_path: "assets/wan_animate/src_face.mp4" +background_video_path: "" +mask_video_path: "" +segment_frame_length: 77 +prev_segment_conditioning_frames: 1 +motion_encode_batch_size: null +animate_guidance_scale: 1.0 + +# SDXL Lightning parameters +lightning_from_pt: True +# Empty or "ByteDance/SDXL-Lightning" to enable lightning. +lightning_repo: "" +# Empty or "sdxl_lightning_4step_unet.safetensors" to enable lightning. +lightning_ckpt: "" + +# LoRA parameters +enable_lora: False +# Values are lists to support multiple LoRA loading during inference in the future. 
+lora_config: { + rank: [64], + lora_model_name_or_path: ["lightx2v/Wan2.2-Distill-Loras"], + high_noise_weight_name: ["wan2.2_t2v_A14b_high_noise_lora_rank64_lightx2v_4step_1217.safetensors"], + low_noise_weight_name: ["wan2.2_t2v_A14b_low_noise_lora_rank64_lightx2v_4step_1217.safetensors"], + adapter_name: ["wan22-distill-lora"], + scale: [1.0], + from_pt: [] +} +# Ex with values: +# lora_config : { +# lora_model_name_or_path: ["ByteDance/Hyper-SD"], +# weight_name: ["Hyper-SDXL-2steps-lora.safetensors"], +# adapter_name: ["hyper-sdxl"], +# scale: [0.7], +# from_pt: [True] +# } + +enable_mllog: False + +#controlnet +controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0' +controlnet_from_pt: True +controlnet_conditioning_scale: 0.5 +controlnet_image: 'https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png' +quantization: '' +# Shard the range finding operation for quantization. By default this is set to number of slices. +quantization_local_shard_count: -1 +compile_topology_num_slices: -1 # Number of target slices, set to a positive integer. +use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix. +# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80 +quantization_calibration_method: "absmax" +qwix_module_path: ".*" + +# Eval model on per eval_every steps. -1 means don't eval. +eval_every: -1 +eval_data_dir: "" +enable_generate_video_for_eval: False # This will increase the used TPU memory. +eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list). 
+ +enable_ssim: False diff --git a/src/maxdiffusion/generate_wan_animate.py b/src/maxdiffusion/generate_wan_animate.py new file mode 100644 index 00000000..9ebeaac9 --- /dev/null +++ b/src/maxdiffusion/generate_wan_animate.py @@ -0,0 +1,233 @@ +# Copyright 2026 Google LLC +# Licensed under the Apache License, Version 2.0 (the "License"); + +import jax +import os +import time +from absl import app +from maxdiffusion import pyconfig, max_logging, max_utils +from maxdiffusion.train_utils import transformer_engine_context +from maxdiffusion.utils import export_to_video +from maxdiffusion.utils.loading_utils import load_image, load_video +import flax +from maxdiffusion.pipelines.wan.wan_pipeline_animate import WanAnimatePipeline +import numpy as np +from PIL import Image + +jax.config.update("jax_use_shardy_partitioner", True) + + +def _get_animate_inference_settings(config): + """Resolve animate-specific inference settings with upstream defaults.""" + return { + "segment_frame_length": getattr(config, "segment_frame_length", 77), + "prev_segment_conditioning_frames": getattr(config, "prev_segment_conditioning_frames", 1), + "motion_encode_batch_size": getattr(config, "motion_encode_batch_size", None), + "guidance_scale": getattr(config, "animate_guidance_scale", 1.0), + } + + +def _frame_summary(name, frames): + """Return a compact frame-count/size summary for logging.""" + if not frames: + return f"{name}_frames=0" + return f"{name}_frames={len(frames)}, {name}_frame_size={getattr(frames[0], 'size', None)}" + + +def run(config): + writer = max_utils.initialize_summary_writer(config) + if jax.process_index() == 0 and writer: + max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}") + + load_start = time.perf_counter() + pipeline = WanAnimatePipeline.from_pretrained(config) + load_time = time.perf_counter() - load_start + max_logging.log(f"load_time: {load_time:.1f}s") + + # Setup inputs + reference_image_path = getattr(config, 
"reference_image_path", "") + if reference_image_path: + image = load_image(reference_image_path) + reference_image_source = reference_image_path + else: + image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + image = load_image(image_url) + reference_image_source = image_url + + mode = getattr(config, "mode", "animate") + pose_video_path = getattr(config, "pose_video_path", "") + face_video_path = getattr(config, "face_video_path", "") + background_video_path = getattr(config, "background_video_path", "") + mask_video_path = getattr(config, "mask_video_path", "") + + num_frames = config.num_frames + height = config.height + width = config.width + + # face_video needs to match motion_encoder_size (probably 224x224 or 256x256) + motion_encoder_size = pipeline.transformer.config.motion_encoder_size + + if pose_video_path and face_video_path: + max_logging.log( + f"Loading preprocessed videos from disk. pose_video={pose_video_path}, face_video={face_video_path}" + ) + pose_video = load_video(pose_video_path) + face_video = load_video(face_video_path) + num_frames = min(num_frames, len(pose_video), len(face_video)) + if num_frames == 0: + raise ValueError("Loaded empty pose/face video. Check preprocessing outputs.") + pose_video = pose_video[:num_frames] + face_video = face_video[:num_frames] + else: + # Fallback path used for quick smoke tests only. + max_logging.log( + "No pose/face video paths provided; generating dummy videos for a smoke test only. " + "For real outputs provide preprocessed pose_video_path and face_video_path." 
+ ) + pose_video = [Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8)) for _ in range(num_frames)] + face_video = [Image.fromarray(np.zeros((motion_encoder_size, motion_encoder_size, 3), dtype=np.uint8)) for _ in range(num_frames)] + + background_video = None + mask_video = None + if mode == "replace": + if not background_video_path or not mask_video_path: + raise ValueError("Replace mode requires both `background_video_path` and `mask_video_path`.") + background_video = load_video(background_video_path)[:num_frames] + mask_video = load_video(mask_video_path)[:num_frames] + + max_logging.log( + "Wan animate inputs: reference_image=%s, image_size=%s, pose_video_path=%s, face_video_path=%s, %s, %s" + % ( + reference_image_source, + getattr(image, "size", None), + pose_video_path or "", + face_video_path or "", + _frame_summary("pose", pose_video), + _frame_summary("face", face_video), + ) + ) + if mode == "replace": + max_logging.log( + "Wan replace inputs: background_video_path=%s, mask_video_path=%s, %s, %s" + % ( + background_video_path, + mask_video_path, + _frame_summary("background", background_video), + _frame_summary("mask", mask_video), + ) + ) + + animate_settings = _get_animate_inference_settings(config) + prompt = config.prompt + negative_prompt = config.negative_prompt if animate_settings["guidance_scale"] > 1.0 else None + + max_logging.log( + "Num steps: %s, height: %s, width: %s, frames: %s, segment_frame_length: %s, " + "prev_segment_conditioning_frames: %s, guidance_scale: %s" + % ( + config.num_inference_steps, + height, + width, + num_frames, + animate_settings["segment_frame_length"], + animate_settings["prev_segment_conditioning_frames"], + animate_settings["guidance_scale"], + ) + ) + + s0 = time.perf_counter() + + # First pass (compile) + videos = pipeline( + image=image, + pose_video=pose_video, + face_video=face_video, + background_video=background_video, + mask_video=mask_video, + prompt=prompt, + negative_prompt=negative_prompt, 
+ height=height, + width=width, + segment_frame_length=animate_settings["segment_frame_length"], + prev_segment_conditioning_frames=animate_settings["prev_segment_conditioning_frames"], + motion_encode_batch_size=animate_settings["motion_encode_batch_size"], + guidance_scale=animate_settings["guidance_scale"], + num_inference_steps=config.num_inference_steps, + mode=mode, + ) + + compile_time = time.perf_counter() - s0 + max_logging.log(f"compile_time: {compile_time}") + if writer and jax.process_index() == 0: + writer.add_scalar("inference/compile_time", compile_time, global_step=0) + + s0 = time.perf_counter() + videos = pipeline( + image=image, + pose_video=pose_video, + face_video=face_video, + background_video=background_video, + mask_video=mask_video, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + segment_frame_length=animate_settings["segment_frame_length"], + prev_segment_conditioning_frames=animate_settings["prev_segment_conditioning_frames"], + motion_encode_batch_size=animate_settings["motion_encode_batch_size"], + guidance_scale=animate_settings["guidance_scale"], + num_inference_steps=config.num_inference_steps, + mode=mode, + ) + + generation_time = time.perf_counter() - s0 + max_logging.log(f"generation_time: {generation_time}") + if writer and jax.process_index() == 0: + writer.add_scalar("inference/generation_time", generation_time, global_step=0) + + filename_prefix = "animate_" + os.makedirs(config.output_dir, exist_ok=True) + for i in range(len(videos)): + video_path = os.path.join(config.output_dir, f"{filename_prefix}wan_output_{config.seed}_{i}.mp4") + export_to_video(videos[i], video_path, fps=config.fps) + max_logging.log(f"Saved video to {video_path}") + + if getattr(config, "enable_profiler", False): + s0 = time.perf_counter() + max_utils.activate_profiler(config) + _ = pipeline( + image=image, + pose_video=pose_video, + face_video=face_video, + background_video=background_video, + 
mask_video=mask_video, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + segment_frame_length=animate_settings["segment_frame_length"], + prev_segment_conditioning_frames=animate_settings["prev_segment_conditioning_frames"], + motion_encode_batch_size=animate_settings["motion_encode_batch_size"], + guidance_scale=animate_settings["guidance_scale"], + num_inference_steps=config.num_inference_steps, + mode=mode, + ) + max_utils.deactivate_profiler(config) + generation_time_with_profiler = time.perf_counter() - s0 + max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}") + if writer and jax.process_index() == 0: + writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0) + + return videos + +def main(argv) -> None: + pyconfig.initialize(argv) + try: + flax.config.update("flax_always_shard_variable", False) + except LookupError: + pass + run(pyconfig.config) + +if __name__ == "__main__": + with transformer_engine_context(): + app.run(main) diff --git a/src/maxdiffusion/image_processor.py b/src/maxdiffusion/image_processor.py index 76fa7635..3d65c3a5 100644 --- a/src/maxdiffusion/image_processor.py +++ b/src/maxdiffusion/image_processor.py @@ -14,7 +14,7 @@ # limitations under the License. import warnings -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -93,8 +93,14 @@ class VaeImageProcessor(ConfigMixin): `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. vae_scale_factor (`int`, *optional*, defaults to `8`): VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. + spatial_patch_size (`Tuple[int, int]`, *optional*, defaults to `(1, 1)`): + Additional height/width patch multiple used by models that require image sizes aligned to + `vae_scale_factor * spatial_patch_size`. 
resample (`str`, *optional*, defaults to `lanczos`): Resampling filter to use when resizing the image. + resize_mode (`str`, *optional*, defaults to `"default"`): + Resize strategy for PIL images. `"default"` resizes directly to the target size. `"fill"` preserves aspect + ratio and letterboxes the image to the target size. do_normalize (`bool`, *optional*, defaults to `True`): Whether to normalize the image to [-1,1]. do_binarize (`bool`, *optional*, defaults to `False`): @@ -103,6 +109,8 @@ class VaeImageProcessor(ConfigMixin): Whether to convert the images to RGB format. do_convert_grayscale (`bool`, *optional*, defaults to be `False`): Whether to convert the images to grayscale format. + fill_color (`Union[int, float, Tuple, str]`, *optional*, defaults to `0`): + Fill color used when `resize_mode="fill"`. """ config_name = CONFIG_NAME @@ -112,11 +120,14 @@ def __init__( self, do_resize: bool = True, vae_scale_factor: int = 8, + spatial_patch_size: Tuple[int, int] = (1, 1), resample: str = "lanczos", + resize_mode: str = "default", do_normalize: bool = True, do_binarize: bool = False, do_convert_rgb: bool = False, do_convert_grayscale: bool = False, + fill_color: Union[int, float, Tuple, str] = 0, ): super().__init__() if do_convert_rgb and do_convert_grayscale: @@ -126,6 +137,8 @@ def __init__( " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`", ) self.config.do_convert_rgb = False + if resize_mode not in {"default", "fill"}: + raise ValueError(f"Unsupported resize_mode '{resize_mode}'. 
Expected one of: 'default', 'fill'.") @staticmethod def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image: @@ -243,12 +256,35 @@ def get_default_height_width( else: width = image.shape[2] - width, height = ( - x - x % self.config.vae_scale_factor for x in (width, height) - ) # resize to integer multiple of vae_scale_factor + mod_h = self.config.vae_scale_factor * self.config.spatial_patch_size[0] + mod_w = self.config.vae_scale_factor * self.config.spatial_patch_size[1] + width, height = (x - x % mod for x, mod in ((width, mod_w), (height, mod_h))) return height, width + def resize_and_fill( + self, + image: PIL.Image.Image, + width: int, + height: int, + ) -> PIL.Image.Image: + """ + Resize a PIL image to fit within the target size while preserving aspect ratio, then letterbox it. + """ + target_ratio = width / height + source_ratio = image.width / image.height + + resized_width = width if target_ratio < source_ratio else image.width * height // image.height + resized_height = height if target_ratio >= source_ratio else image.height * width // image.width + + resized = image.resize((resized_width, resized_height), resample=PIL_INTERPOLATION[self.config.resample]) + canvas = PIL.Image.new("RGB", (width, height), color=self.config.fill_color) + canvas.paste( + resized, + box=(width // 2 - resized_width // 2, height // 2 - resized_height // 2), + ) + return canvas + def resize( self, image: [PIL.Image.Image, np.ndarray, torch.Tensor], @@ -259,7 +295,10 @@ def resize( Resize image. 
""" if isinstance(image, PIL.Image.Image): - image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample]) + if self.config.resize_mode == "fill": + image = self.resize_and_fill(image, width, height) + else: + image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample]) elif isinstance(image, torch.Tensor): image = torch.nn.functional.interpolate( image, diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index cc1d9ea1..82eb3cad 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -964,7 +964,7 @@ def __init__( precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, - ("embed",), + ("heads",), ), ) @@ -978,7 +978,7 @@ def __init__( precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, - ("embed",), + ("heads",), ), ) @@ -992,7 +992,7 @@ def __init__( precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, - ("embed",), + ("heads",), ), ) @@ -1006,7 +1006,7 @@ def __init__( precision=precision, bias_init=nnx.with_partitioning( nnx.initializers.zeros, - ("heads",), + ("embed",), ), ) @@ -1317,11 +1317,12 @@ def setup(self): precision=self.precision, ) + proj_attn_kernel_axes = ("heads", "embed") self.proj_attn = nn.Dense( self.query_dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes), + kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), proj_attn_kernel_axes), use_bias=True, - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)), dtype=self.dtype, param_dtype=self.weights_dtype, name="i_proj", @@ -1330,9 +1331,9 @@ def setup(self): self.encoder_proj_attn = nn.Dense( self.query_dim, - kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), kernel_axes), + 
kernel_init=nn.with_logical_partitioning(nn.initializers.lecun_normal(), proj_attn_kernel_axes), use_bias=True, - bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("heads",)), + bias_init=nn.with_logical_partitioning(nn.initializers.zeros, ("embed",)), dtype=self.dtype, param_dtype=self.weights_dtype, name="e_proj", diff --git a/src/maxdiffusion/models/wan/transformers/__init__.py b/src/maxdiffusion/models/wan/transformers/__init__.py index 4a62083b..23735b2e 100644 --- a/src/maxdiffusion/models/wan/transformers/__init__.py +++ b/src/maxdiffusion/models/wan/transformers/__init__.py @@ -13,3 +13,5 @@ See the License for the specific language governing permissions and limitations under the License. """ + +from .transformer_wan_animate import NNXWanAnimateTransformer3DModel diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index e701ab92..ff7acb7b 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -193,11 +193,11 @@ def __init__( kernel_init=nnx.with_partitioning( nnx.initializers.xavier_uniform(), ( - "mlp", "embed", + "mlp", ), ), - bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("mlp",)), ) def __call__(self, x: jax.Array) -> jax.Array: @@ -249,8 +249,8 @@ def __init__( kernel_init=nnx.with_partitioning( nnx.initializers.xavier_uniform(), ( - "embed", "mlp", + "embed", ), ), ) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py b/src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py new file mode 100644 index 00000000..db6d5795 --- /dev/null +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py @@ -0,0 +1,1084 @@ +""" +Copyright 2026 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file 
except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Tuple, Optional, Dict, Union, Any +import contextlib +import math +import numpy as np +import jax +import jax.numpy as jnp +import flax.linen as nn +from flax import nnx +from .... import common_types +from ...modeling_flax_utils import FlaxModelMixin +from ....configuration_utils import ConfigMixin, register_to_config +from ...normalization_flax import FP32LayerNorm +from ...gradient_checkpoint import GradientCheckpointType +from .transformer_wan import ( + WanRotaryPosEmbed, + WanTimeTextImageEmbedding, + WanTransformerBlock, +) + +BlockSizes = common_types.BlockSizes + +WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES = { + "4": 512, + "8": 512, + "16": 512, + "32": 512, + "64": 256, + "128": 128, + "256": 64, + "512": 32, + "1024": 16, +} + + +class FusedLeakyReLU(nnx.Module): + """ + Fused LeakyRelu with scale factor and channel-wise bias. 
+ """ + + def __init__( + self, + rngs: nnx.Rngs, + negative_slope: float = 0.2, + scale: float = 2**0.5, + bias_channels: Optional[int] = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + ): + self.negative_slope = negative_slope + self.scale = scale + self.channels = bias_channels + self.dtype = dtype + self.weights_dtype = weights_dtype + + if self.channels is not None: + self.bias = nnx.Param(jnp.zeros((self.channels,), dtype=self.weights_dtype)) + else: + self.bias = None + + def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array: + if self.bias is not None: + # Expand self.bias to have all singleton dims except at channel_dim + expanded_shape = [1] * x.ndim + expanded_shape[channel_dim] = self.channels + bias = jnp.reshape(self.bias, expanded_shape) + x = x + bias + x = jax.nn.leaky_relu(x, self.negative_slope) * self.scale + return x + + +class MotionConv2d(nnx.Module): + """2-D convolution with EqualizedLR scaling and optional FusedLeakyReLU. + + Weights are stored in PyTorch OIHW format (out, in, k, k) as raw nnx.Param + so that the weight-loading code in wan_utils.py can map them without + transposing. No sharding annotations are applied because this module is + part of the small motion encoder network. 
+ """ + + def __init__( + self, + rngs: nnx.Rngs, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + bias: bool = True, + blur_kernel: Optional[Tuple[int, ...]] = None, + blur_upsample_factor: int = 1, + use_activation: bool = True, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + ): + self.use_activation = use_activation + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding_size = padding + self.dtype = dtype + self.weights_dtype = weights_dtype + + self.blur = False + if blur_kernel is not None: + p = (len(blur_kernel) - stride) + (kernel_size - 1) + self.blur_padding = ((p + 1) // 2, p // 2) + + kernel = np.asarray(blur_kernel, dtype=np.float32) + if kernel.ndim == 1: + kernel = np.expand_dims(kernel, 0) * np.expand_dims(kernel, 1) + kernel = kernel / kernel.sum() + + if blur_upsample_factor > 1: + kernel = kernel * (blur_upsample_factor**2) + + self.blur_kernel = nnx.static(tuple(tuple(float(v) for v in row) for row in kernel)) + self.blur = True + else: + self.blur_kernel = nnx.static(None) + + key = rngs.params() + # Shape: (out_channels, in_channels, kernel, kernel) — PyTorch OIHW format. + self.weight = nnx.Param( + jax.random.normal(key, (out_channels, in_channels, kernel_size, kernel_size), dtype=weights_dtype) + ) + self.scale = 1.0 / math.sqrt(in_channels * kernel_size**2) + + if bias and not self.use_activation: + self.bias = nnx.Param(jnp.zeros((out_channels,), dtype=weights_dtype)) + else: + self.bias = None + + if self.use_activation: + self.act_fn = FusedLeakyReLU( + rngs=rngs, bias_channels=out_channels, dtype=dtype, weights_dtype=weights_dtype + ) + else: + self.act_fn = None + + def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array: + # 1. 
Blur Pass (Depthwise) + if self.blur: + blur_kernel = jnp.asarray(self.blur_kernel, dtype=jnp.float32) + expanded_kernel = jnp.expand_dims(jnp.expand_dims(blur_kernel, 0), 0) + expanded_kernel = jnp.broadcast_to( + expanded_kernel, + ( + self.in_channels, + 1, + expanded_kernel.shape[2], + expanded_kernel.shape[3], + ), + ) + x = x.astype(expanded_kernel.dtype) + + pad_h, pad_w = self.blur_padding + x = jax.lax.conv_general_dilated( + x, + expanded_kernel, + window_strides=(1, 1), + padding=[(pad_h, pad_h), (pad_w, pad_w)], + dimension_numbers=("NCHW", "OIHW", "NCHW"), + feature_group_count=self.in_channels, + ) + + # 2. Main Convolution Pass + x = x.astype(self.weight.dtype) + conv_weight = self.weight * self.scale + x = jax.lax.conv_general_dilated( + x, + conv_weight, + window_strides=(self.stride, self.stride), + padding=[ + (self.padding_size, self.padding_size), + (self.padding_size, self.padding_size), + ], + dimension_numbers=("NCHW", "OIHW", "NCHW"), + ) + + # 3. Bias and Activation + if self.bias is not None: + b = jnp.reshape(self.bias, (1, self.out_channels, 1, 1)) + x = x + b + + if self.use_activation: + x = self.act_fn(x, channel_dim=channel_dim) + + return x + + +class MotionLinear(nnx.Module): + """Equalized-LR linear layer with optional FusedLeakyReLU. + + Weights are stored in PyTorch (out, in) format as raw nnx.Param — same + reason as MotionConv2d. No sharding annotations needed (small layer). 
+ """ + + def __init__( + self, + rngs: nnx.Rngs, + in_dim: int, + out_dim: int, + bias: bool = True, + use_activation: bool = False, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + ): + self.use_activation = use_activation + self.in_dim = in_dim + self.out_dim = out_dim + self.dtype = dtype + self.weights_dtype = weights_dtype + + key = rngs.params() + self.weight = nnx.Param(jax.random.normal(key, (out_dim, in_dim), dtype=weights_dtype)) + self.scale = 1.0 / math.sqrt(in_dim) + + if bias and not self.use_activation: + self.bias = nnx.Param(jnp.zeros((out_dim,), dtype=weights_dtype)) + else: + self.bias = None + + if self.use_activation: + self.act_fn = FusedLeakyReLU(rngs=rngs, bias_channels=out_dim, dtype=dtype, weights_dtype=weights_dtype) + else: + self.act_fn = None + + def __call__(self, inputs: jax.Array, channel_dim: int = 1) -> jax.Array: + inputs = inputs.astype(self.weight.dtype) + # Transpose to (in_dim, out_dim) and apply scale + w = self.weight.T * self.scale + + out = inputs @ w + + if self.bias is not None: + out = out + self.bias + + if self.use_activation: + out = self.act_fn(out, channel_dim=channel_dim) + + return out + + +class MotionEncoderResBlock(nnx.Module): + + def __init__( + self, + rngs: nnx.Rngs, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + kernel_size_skip: int = 1, + blur_kernel: Tuple[int, ...] 
= (1, 3, 3, 1), + downsample_factor: int = 2, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + ): + self.downsample_factor = downsample_factor + self.dtype = dtype + + # 3 X 3 Conv + fused leaky ReLU + self.conv1 = MotionConv2d( + rngs, + in_channels, + in_channels, + kernel_size, + stride=1, + padding=kernel_size // 2, + use_activation=True, + dtype=dtype, + weights_dtype=weights_dtype, + ) + + # 3 X 3 Conv + downsample 2x + fused leaky ReLU + self.conv2 = MotionConv2d( + rngs, + in_channels, + out_channels, + kernel_size=kernel_size, + stride=self.downsample_factor, + padding=0, + blur_kernel=blur_kernel, + use_activation=True, + dtype=dtype, + weights_dtype=weights_dtype, + ) + + # 1 X 1 Conv + downsample 2x in skip connection + self.conv_skip = MotionConv2d( + rngs, + in_channels, + out_channels, + kernel_size=kernel_size_skip, + stride=self.downsample_factor, + padding=0, + bias=False, + blur_kernel=blur_kernel, + use_activation=False, + dtype=dtype, + weights_dtype=weights_dtype, + ) + + def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array: + x_out = self.conv1(x, channel_dim=channel_dim) + x_out = self.conv2(x_out, channel_dim=channel_dim) + + x_skip = self.conv_skip(x, channel_dim=channel_dim) + + x_out = (x_out + x_skip) / math.sqrt(2.0) + return x_out + + +class WanAnimateMotionEncoder(nnx.Module): + """Encodes a face video frame into a motion vector. + + All weights in this network are small (the largest is 32×512→16) so + sharding annotations are not applied. 
+ """ + + def __init__( + self, + rngs: nnx.Rngs, + size: int = 512, + style_dim: int = 512, + motion_dim: int = 20, + out_dim: int = 512, + motion_blocks: int = 5, + channels: Optional[Dict[str, int]] = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + ): + self.size = size + self.dtype = dtype + self.weights_dtype = weights_dtype + + if channels is None: + channels = WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES + + self.conv_in = MotionConv2d( + rngs, 3, channels[str(size)], 1, use_activation=True, dtype=dtype, weights_dtype=weights_dtype + ) + + res_blocks = [] + in_channels = channels[str(size)] + log_size = int(math.log(size, 2)) + for i in range(log_size, 2, -1): + out_channels = channels[str(2 ** (i - 1))] + res_blocks.append( + MotionEncoderResBlock(rngs, in_channels, out_channels, dtype=dtype, weights_dtype=weights_dtype) + ) + in_channels = out_channels + self.res_blocks = nnx.List(res_blocks) + + self.conv_out = MotionConv2d( + rngs, + in_channels, + style_dim, + 4, + padding=0, + bias=False, + use_activation=False, + dtype=dtype, + weights_dtype=weights_dtype, + ) + + linears = [] + for _ in range(motion_blocks - 1): + linears.append(MotionLinear(rngs, style_dim, style_dim, dtype=dtype, weights_dtype=weights_dtype)) + + linears.append(MotionLinear(rngs, style_dim, motion_dim, dtype=dtype, weights_dtype=weights_dtype)) + self.motion_network = nnx.List(linears) + + key = rngs.params() + self.motion_synthesis_weight = nnx.Param(jax.random.normal(key, (out_dim, motion_dim), dtype=weights_dtype)) + + def __call__(self, face_image: jax.Array, channel_dim: int = 1) -> jax.Array: + if face_image.shape[-2] != self.size or face_image.shape[-1] != self.size: + raise ValueError(f"Expected {self.size} got {face_image.shape[-1]}") + + x = self.conv_in(face_image, channel_dim=channel_dim) + for block in self.res_blocks: + x = block(x, channel_dim=channel_dim) + x = self.conv_out(x, channel_dim=channel_dim) + + motion_feat = jnp.squeeze(x, 
axis=(-1, -2)) + + for linear_layer in self.motion_network: + motion_feat = linear_layer(motion_feat, channel_dim=channel_dim) + + weight = self.motion_synthesis_weight[...] + 1e-8 + + original_dtype = motion_feat.dtype + motion_feat_fp32 = motion_feat.astype(jnp.float32) + weight_fp32 = weight.astype(jnp.float32) + + Q, _ = jnp.linalg.qr(weight_fp32) + + motion_vec = jnp.matmul(motion_feat_fp32, jnp.transpose(Q, (1, 0))) + + return motion_vec.astype(original_dtype) + + +class WanAnimateFaceEncoder(nnx.Module): + + def __init__( + self, + rngs: nnx.Rngs, + in_dim: int, + out_dim: int, + hidden_dim: int = 1024, + num_heads: int = 4, + kernel_size: int = 3, + eps: float = 1e-6, + pad_mode: str = "edge", + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + ): + self.num_heads = num_heads + self.kernel_size = kernel_size + self.pad_mode = pad_mode + self.out_dim = out_dim + self.dtype = dtype + + self.act = jax.nn.silu + + self.conv1_local = nnx.Conv( + in_dim, + hidden_dim * num_heads, + kernel_size=(kernel_size,), + strides=(1,), + padding="VALID", + rngs=rngs, + dtype=dtype, + param_dtype=weights_dtype, + ) + self.conv2 = nnx.Conv( + hidden_dim, + hidden_dim, + kernel_size=(kernel_size,), + strides=(2,), + padding="VALID", + rngs=rngs, + dtype=dtype, + param_dtype=weights_dtype, + ) + self.conv3 = nnx.Conv( + hidden_dim, + hidden_dim, + kernel_size=(kernel_size,), + strides=(2,), + padding="VALID", + rngs=rngs, + dtype=dtype, + param_dtype=weights_dtype, + ) + + self.norm1 = nnx.LayerNorm( + hidden_dim, + epsilon=eps, + use_bias=False, + use_scale=False, + rngs=rngs, + dtype=dtype, + ) + self.norm2 = nnx.LayerNorm( + hidden_dim, + epsilon=eps, + use_bias=False, + use_scale=False, + rngs=rngs, + dtype=dtype, + ) + self.norm3 = nnx.LayerNorm( + hidden_dim, + epsilon=eps, + use_bias=False, + use_scale=False, + rngs=rngs, + dtype=dtype, + ) + + # hidden_dim (mlp) → out_dim (embed): ("mlp", "embed") + self.out_proj = nnx.Linear( + hidden_dim, + 
out_dim, + rngs=rngs, + dtype=dtype, + param_dtype=weights_dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("mlp", "embed")), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + ) + + self.padding_tokens = nnx.Param(jnp.zeros((1, 1, 1, out_dim), dtype=weights_dtype)) + + def __call__(self, x: jax.Array) -> jax.Array: + batch_size = x.shape[0] + + # Local attention via causal convolution + x = jnp.pad(x, ((0, 0), (self.kernel_size - 1, 0), (0, 0)), mode=self.pad_mode) + x = self.conv1_local(x) + + x = jnp.reshape(x, (batch_size, x.shape[1], self.num_heads, -1)) + x = jnp.transpose(x, (0, 2, 1, 3)) + x = jnp.reshape(x, (batch_size * self.num_heads, x.shape[2], x.shape[3])) + + x = self.norm1(x) + x = self.act(x) + + x = jnp.pad(x, ((0, 0), (self.kernel_size - 1, 0), (0, 0)), mode=self.pad_mode) + x = self.conv2(x) + x = self.norm2(x) + x = self.act(x) + + x = jnp.pad(x, ((0, 0), (self.kernel_size - 1, 0), (0, 0)), mode=self.pad_mode) + x = self.conv3(x) + x = self.norm3(x) + x = self.act(x) + + x = self.out_proj(x) + + x = jnp.reshape(x, (batch_size, self.num_heads, x.shape[1], x.shape[2])) + x = jnp.transpose(x, (0, 2, 1, 3)) + + padding = jnp.broadcast_to(self.padding_tokens[...], (batch_size, x.shape[1], 1, self.out_dim)) + x = jnp.concatenate([x, padding], axis=2) + + return x + + +class WanAnimateFaceBlockCrossAttention(nnx.Module): + + def __init__( + self, + rngs: nnx.Rngs, + dim: int, + heads: int = 8, + dim_head: int = 64, + eps: float = 1e-6, + cross_attention_dim_head: Optional[int] = None, + use_bias: bool = True, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + ): + self.heads = heads + self.inner_dim = dim_head * heads + self.cross_attention_dim_head = cross_attention_dim_head + self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads + self.dtype = dtype + + self.pre_norm_q = nnx.LayerNorm(dim, epsilon=eps, use_bias=False, 
use_scale=False, rngs=rngs, dtype=dtype) + self.pre_norm_kv = nnx.LayerNorm(dim, epsilon=eps, use_bias=False, use_scale=False, rngs=rngs, dtype=dtype) + + # embed → heads + self.to_q = nnx.Linear( + dim, + self.inner_dim, + use_bias=use_bias, + rngs=rngs, + dtype=dtype, + param_dtype=weights_dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("heads",)), + ) + self.to_k = nnx.Linear( + dim, + self.kv_inner_dim, + use_bias=use_bias, + rngs=rngs, + dtype=dtype, + param_dtype=weights_dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("heads",)), + ) + self.to_v = nnx.Linear( + dim, + self.kv_inner_dim, + use_bias=use_bias, + rngs=rngs, + dtype=dtype, + param_dtype=weights_dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("heads",)), + ) + + # heads → embed + self.to_out = nnx.Linear( + self.inner_dim, + dim, + use_bias=use_bias, + rngs=rngs, + dtype=dtype, + param_dtype=weights_dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), + ) + + self.norm_q = nnx.RMSNorm(dim_head, epsilon=eps, use_scale=True, rngs=rngs, dtype=dtype, param_dtype=weights_dtype) + self.norm_k = nnx.RMSNorm(dim_head, epsilon=eps, use_scale=True, rngs=rngs, dtype=dtype, param_dtype=weights_dtype) + + def __call__( + self, + hidden_states: jax.Array, + encoder_hidden_states: jax.Array, + attention_mask: Optional[jax.Array] = None, + ) -> jax.Array: + hidden_states = self.pre_norm_q(hidden_states) + encoder_hidden_states = self.pre_norm_kv(encoder_hidden_states) + + B, T, N, C = encoder_hidden_states.shape + + query = self.to_q(hidden_states) + key = 
self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + # Reshape to extract heads + query = jnp.reshape(query, (query.shape[0], query.shape[1], self.heads, -1)) + key = jnp.reshape(key, (B, T, N, self.heads, -1)) + value = jnp.reshape(value, (B, T, N, self.heads, -1)) + + query = self.norm_q(query) + key = self.norm_k(key) + + query_S = query.shape[1] + + # Fold Time into the Batch dimension for attention + query = jnp.reshape(query, (B * T, query_S // T, self.heads, -1)) + key = jnp.reshape(key, (B * T, N, self.heads, -1)) + value = jnp.reshape(value, (B * T, N, self.heads, -1)) + + attn_output = jax.nn.dot_product_attention(query, key, value) + + # Restore (Batch, Total Sequence, Dim) + attn_output = jnp.reshape(attn_output, (B, query_S, -1)) + + hidden_states = self.to_out(attn_output) + + if attention_mask is not None: + attention_mask = jnp.reshape(attention_mask, (attention_mask.shape[0], -1)) + hidden_states = hidden_states * jnp.expand_dims(attention_mask, axis=-1) + + return hidden_states + + +class NNXWanAnimateTransformer3DModel(nnx.Module, FlaxModelMixin, ConfigMixin): + + @register_to_config + def __init__( + self, + rngs: nnx.Rngs, + model_type="t2v", + patch_size: Tuple[int, int, int] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 36, + latent_channels: int = 16, + out_channels: Optional[int] = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + dropout: float = 0.0, + cross_attn_norm: bool = True, + qk_norm: Optional[str] = "rms_norm_across_heads", + eps: float = 1e-6, + image_dim: Optional[int] = 1280, + added_kv_proj_dim: Optional[int] = None, + rope_max_seq_len: int = 1024, + pos_embed_seq_len: Optional[int] = None, + image_seq_len: Optional[int] = None, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype 
= jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", + remat_policy: str = "None", + names_which_can_be_saved: list = [], + names_which_can_be_offloaded: list = [], + mask_padding_tokens: bool = True, + scan_layers: bool = True, + enable_jax_named_scopes: bool = False, + motion_encoder_channel_sizes: Optional[Dict[str, int]] = None, + motion_encoder_size: int = 512, + motion_style_dim: int = 512, + motion_dim: int = 20, + motion_encoder_dim: int = 512, + face_encoder_hidden_dim: int = 1024, + face_encoder_num_heads: int = 4, + inject_face_latents_blocks: int = 5, + motion_encoder_batch_size: int = 8, + ): + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or latent_channels + + self.num_layers = num_layers + self.scan_layers = scan_layers + self.enable_jax_named_scopes = enable_jax_named_scopes + self.patch_size = patch_size + self.inject_face_latents_blocks = inject_face_latents_blocks + self.motion_encoder_batch_size = motion_encoder_batch_size + self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy) + self.names_which_can_be_saved = names_which_can_be_saved + self.names_which_can_be_offloaded = names_which_can_be_offloaded + + self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + + # Patch embeddings — shard output (conv_out) axis across model parallelism. 
+ self.patch_embedding = nnx.Conv( + in_channels, + inner_dim, + kernel_size=patch_size, + strides=patch_size, + rngs=rngs, + dtype=dtype, + param_dtype=weights_dtype, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + (None, None, None, None, "conv_out"), + ), + ) + self.pose_patch_embedding = nnx.Conv( + latent_channels, + inner_dim, + kernel_size=patch_size, + strides=patch_size, + rngs=rngs, + dtype=dtype, + param_dtype=weights_dtype, + kernel_init=nnx.with_partitioning( + nnx.initializers.xavier_uniform(), + (None, None, None, None, "conv_out"), + ), + ) + + self.condition_embedder = WanTimeTextImageEmbedding( + rngs=rngs, + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embed_dim=image_dim, + pos_embed_seq_len=pos_embed_seq_len, + flash_min_seq_length=flash_min_seq_length, + dtype=dtype, + weights_dtype=weights_dtype, + ) + + self.motion_encoder = WanAnimateMotionEncoder( + rngs=rngs, + size=motion_encoder_size, + style_dim=motion_style_dim, + motion_dim=motion_dim, + out_dim=motion_encoder_dim, + channels=motion_encoder_channel_sizes, + dtype=dtype, + weights_dtype=weights_dtype, + ) + self.face_encoder = WanAnimateFaceEncoder( + rngs=rngs, + in_dim=motion_encoder_dim, + out_dim=inner_dim, + hidden_dim=face_encoder_hidden_dim, + num_heads=face_encoder_num_heads, + dtype=dtype, + weights_dtype=weights_dtype, + ) + + @nnx.split_rngs(splits=num_layers) + @nnx.vmap(in_axes=0, out_axes=0, transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"}) + def init_block(rngs): + return WanTransformerBlock( + rngs=rngs, + dim=inner_dim, + ffn_dim=ffn_dim, + num_heads=num_attention_heads, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + image_seq_len=image_seq_len, + flash_min_seq_length=flash_min_seq_length, + flash_block_sizes=flash_block_sizes, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + 
attention=attention, + dropout=dropout, + mask_padding_tokens=mask_padding_tokens, + enable_jax_named_scopes=enable_jax_named_scopes, + ) + + if scan_layers: + self.blocks = init_block(rngs) + else: + blocks = [] + for _ in range(num_layers): + block = WanTransformerBlock( + rngs=rngs, + dim=inner_dim, + ffn_dim=ffn_dim, + num_heads=num_attention_heads, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + image_seq_len=image_seq_len, + flash_min_seq_length=flash_min_seq_length, + flash_block_sizes=flash_block_sizes, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + attention=attention, + dropout=dropout, + mask_padding_tokens=mask_padding_tokens, + enable_jax_named_scopes=enable_jax_named_scopes, + ) + blocks.append(block) + self.blocks = nnx.List(blocks) + + face_adapters = [] + num_face_adapters = math.ceil(num_layers / inject_face_latents_blocks) + for _ in range(num_face_adapters): + fa = WanAnimateFaceBlockCrossAttention( + rngs=rngs, + dim=inner_dim, + heads=num_attention_heads, + dim_head=inner_dim // num_attention_heads, + eps=eps, + cross_attention_dim_head=inner_dim // num_attention_heads, + dtype=dtype, + weights_dtype=weights_dtype, + ) + face_adapters.append(fa) + self.face_adapter = nnx.List(face_adapters) + + self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False) + + # Final projection — embed → output tokens. 
+ self.proj_out = nnx.Linear( + rngs=rngs, + in_features=inner_dim, + out_features=out_channels * math.prod(patch_size), + dtype=dtype, + param_dtype=weights_dtype, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", None)), + ) + + key = rngs.params() + self.scale_shift_table = nnx.Param( + jax.random.normal(key, (1, 2, inner_dim), dtype=weights_dtype) / inner_dim**0.5, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "embed")), + ) + + def conditional_named_scope(self, name: str): + return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext() + + def _apply_face_adapter(self, hidden_states: jax.Array, motion_vec: Optional[jax.Array], block_idx) -> jax.Array: + if motion_vec is None or len(self.face_adapter) == 0: + return hidden_states + + adapter_idx = block_idx // self.inject_face_latents_blocks + adapter_branches = tuple( + (lambda current_hidden_states, adapter=adapter: current_hidden_states + adapter(current_hidden_states, motion_vec)) + for adapter in self.face_adapter + ) + return jax.lax.switch(adapter_idx, adapter_branches, hidden_states) + + @jax.named_scope("WanAnimateTransformer3DModel") + def __call__( + self, + hidden_states: jax.Array, + timestep: jax.Array, + encoder_hidden_states: jax.Array, + encoder_hidden_states_image: Optional[jax.Array] = None, + pose_hidden_states: Optional[jax.Array] = None, + face_pixel_values: Optional[jax.Array] = None, + motion_encode_batch_size: Optional[int] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + deterministic: bool = True, + rngs: nnx.Rngs = None, + ) -> Union[jax.Array, Dict[str, jax.Array]]: + if pose_hidden_states is not None and pose_hidden_states.shape[2] + 1 != hidden_states.shape[2]: + raise ValueError( + f"Pose frames + 1 ({pose_hidden_states.shape[2]} + 1) must equal hidden_states frames ({hidden_states.shape[2]})" + ) + + # Constrain input to batch-sharded layout 
before any computation. + hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None)) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # 1 & 2. Rotary Position & Patch Embedding + hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1)) + rotary_emb = self.rope(hidden_states) + hidden_states = self.patch_embedding(hidden_states) + + pose_hidden_states = jnp.transpose(pose_hidden_states, (0, 2, 3, 4, 1)) + pose_hidden_states = self.pose_patch_embedding(pose_hidden_states) + pose_pad = jnp.zeros( + ( + batch_size, + 1, + pose_hidden_states.shape[2], + pose_hidden_states.shape[3], + pose_hidden_states.shape[4], + ), + dtype=hidden_states.dtype, + ) + pose_pad = jnp.concatenate([pose_pad, pose_hidden_states], axis=1) + hidden_states = hidden_states + pose_pad + + hidden_states = jnp.reshape(hidden_states, (batch_size, -1, hidden_states.shape[-1])) + + # 3. Condition Embeddings + ( + temb, + timestep_proj, + encoder_hidden_states, + encoder_hidden_states_image, + encoder_attention_mask, + ) = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image) + timestep_proj = timestep_proj.reshape(batch_size, 6, -1) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1) + + # 4. 
Batched Face & Motion Encoding + _, face_channels, num_face_frames, face_height, face_width = face_pixel_values.shape + + # Rearrange from (B, C, T, H, W) to (B*T, C, H, W) + face_pixel_values = jnp.transpose(face_pixel_values, (0, 2, 1, 3, 4)) + face_pixel_values = jnp.reshape(face_pixel_values, (-1, face_channels, face_height, face_width)) + + total_face_frames = face_pixel_values.shape[0] + motion_encode_batch_size = motion_encode_batch_size or self.motion_encoder_batch_size + + # Pad sequence if it doesn't divide evenly by encode_bs + pad_len = (motion_encode_batch_size - (total_face_frames % motion_encode_batch_size)) % motion_encode_batch_size + if pad_len > 0: + pad_tensor = jnp.zeros( + (pad_len, face_channels, face_height, face_width), + dtype=face_pixel_values.dtype, + ) + face_pixel_values = jnp.concatenate([face_pixel_values, pad_tensor], axis=0) + + # Reshape into chunks for scan + num_chunks = face_pixel_values.shape[0] // motion_encode_batch_size + face_chunks = jnp.reshape( + face_pixel_values, + ( + num_chunks, + motion_encode_batch_size, + face_channels, + face_height, + face_width, + ), + ) + + # Use jax.lax.scan to iterate over chunks to save memory + def encode_chunk_fn(carry, chunk): + encoded_chunk = self.motion_encoder(chunk) + return carry, encoded_chunk + + _, motion_vec_chunks = jax.lax.scan(encode_chunk_fn, None, face_chunks) + motion_vec = jnp.reshape(motion_vec_chunks, (-1, motion_vec_chunks.shape[-1])) + + # Remove padding if added + if pad_len > 0: + motion_vec = motion_vec[:-pad_len] + + motion_vec = jnp.reshape(motion_vec, (batch_size, num_face_frames, -1)) + + # Apply face encoder + motion_vec = self.face_encoder(motion_vec) + pad_face = jnp.zeros_like(motion_vec[:, :1]) + motion_vec = jnp.concatenate([pad_face, motion_vec], axis=1) + + # 5. 
Transformer Blocks + if self.scan_layers: + + def scan_fn(carry, block_idx, block): + hidden_states_carry, rngs_carry = carry + hidden_states = block( + hidden_states_carry, + encoder_hidden_states, + timestep_proj, + rotary_emb, + deterministic, + rngs_carry, + encoder_attention_mask=encoder_attention_mask, + ) + + hidden_states = jax.lax.cond( + block_idx % self.inject_face_latents_blocks == 0, + lambda current_hidden_states: self._apply_face_adapter(current_hidden_states, motion_vec, block_idx), + lambda current_hidden_states: current_hidden_states, + hidden_states, + ) + return (hidden_states, rngs_carry), None + + rematted_block_forward = self.gradient_checkpoint.apply( + scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers + ) + initial_carry = (hidden_states, rngs) + final_carry, _ = nnx.scan( + rematted_block_forward, + length=self.num_layers, + in_axes=(nnx.Carry, 0, 0), + out_axes=(nnx.Carry, 0), + )(initial_carry, jnp.arange(self.num_layers), self.blocks) + hidden_states, _ = final_carry + else: + for block_idx, block in enumerate(self.blocks): + + def layer_forward(hidden_states): + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + deterministic, + rngs, + encoder_attention_mask=encoder_attention_mask, + ) + + if motion_vec is not None and block_idx % self.inject_face_latents_blocks == 0: + hidden_states = self._apply_face_adapter(hidden_states, motion_vec, block_idx) + return hidden_states + + rematted_layer_forward = self.gradient_checkpoint.apply( + layer_forward, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers + ) + hidden_states = rematted_layer_forward(hidden_states) + + # 6. 
Output Norm & Projection + shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1) + hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype) + hidden_states = self.proj_out(hidden_states) + + hidden_states = jnp.reshape( + hidden_states, + ( + batch_size, + post_patch_num_frames, + post_patch_height, + post_patch_width, + p_t, + p_h, + p_w, + -1, + ), + ) + hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6)) + hidden_states = jnp.reshape(hidden_states, (batch_size, -1, num_frames, height, width)) + + if not return_dict: + return (hidden_states,) + return {"sample": hidden_states} + diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index aa731923..5fd41567 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -39,6 +39,22 @@ def _tuple_str_to_int(in_tuple): return tuple(out_list) +def _normalize_animate_list_key(key): + """Convert flattened animate list names into nnx.List-style tuple paths.""" + if not key: + return key + + if isinstance(key[0], str) and key[0].startswith("face_adapter_"): + adapter_idx = int(key[0].split("_")[-1]) + return ("face_adapter", adapter_idx) + key[1:] + + if len(key) >= 2 and key[0] == "motion_encoder" and isinstance(key[1], str) and key[1].startswith("motion_network_"): + layer_idx = int(key[1].split("_")[-1]) + return ("motion_encoder", "motion_network", layer_idx) + key[2:] + + return key + + def rename_for_nnx(key): new_key = key if "norm_k" in key or "norm_q" in key: @@ -73,24 +89,92 @@ def rename_for_custom_trasformer(key): def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers=40): + block_index = None if scan_layers: - if "blocks" in pt_tuple_key: - new_key = ("blocks",) + pt_tuple_key[2:] + if len(pt_tuple_key) >= 2 and pt_tuple_key[0] == "blocks": block_index = 
int(pt_tuple_key[1]) - pt_tuple_key = new_key + pt_tuple_key = ("blocks",) + pt_tuple_key[2:] flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, scan_layers) flax_key = rename_for_nnx(flax_key) flax_key = _tuple_str_to_int(flax_key) - if scan_layers: - if "blocks" in flax_key: - if flax_key in flax_state_dict: - new_tensor = flax_state_dict[flax_key] - else: - new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape) - flax_tensor = new_tensor.at[block_index].set(flax_tensor) + if scan_layers and block_index is not None: + if flax_key in flax_state_dict: + new_tensor = flax_state_dict[flax_key] + else: + new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape, dtype=flax_tensor.dtype) + flax_tensor = new_tensor.at[block_index].set(flax_tensor) + return flax_key, flax_tensor + + +def _build_random_flax_state_dict(eval_shapes): + flattened_dict = flatten_dict(eval_shapes) + random_flax_state_dict = {} + for key, value in flattened_dict.items(): + random_flax_state_dict[tuple(str(item) for item in key)] = value + return random_flax_state_dict + + +def _rename_common_wan_transformer_key(renamed_pt_key: str) -> str: + if "condition_embedder" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "time_embedder.linear_1") + renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "time_embedder.linear_2") + renamed_pt_key = renamed_pt_key.replace("time_projection_1", "time_proj") + renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "text_embedder.linear_1") + renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "text_embedder.linear_2") + + if "image_embedder" in renamed_pt_key: + if "net.0.proj" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0") + elif "net_0.proj" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("net_0.proj", "net_0") + if "net.2" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("net.2", "net_2") + 
renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm") + if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("weight", "scale") + renamed_pt_key = renamed_pt_key.replace("kernel", "scale") + + renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.") + renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table") + renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn") + renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out") + renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn") + renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm") + + return renamed_pt_key + + +def _rename_wan_animate_pt_tuple_key(pt_key: str): + renamed_pt_key = _rename_common_wan_transformer_key(rename_key(pt_key)) + is_motion_custom_weight = _is_motion_encoder_custom_weight(pt_key) + + renamed_pt_key = renamed_pt_key.replace(".activation.bias", ".act_fn.bias") + if is_motion_custom_weight and renamed_pt_key.endswith(".kernel"): + renamed_pt_key = renamed_pt_key[:-7] + ".weight" + + return tuple(renamed_pt_key.split(".")), is_motion_custom_weight + + +def get_wan_animate_key_and_value( + pt_tuple_key, + tensor, + flax_state_dict, + random_flax_state_dict, + scan_layers, + is_motion_custom_weight=False, + num_layers=40, +): + if is_motion_custom_weight: + flax_key = _normalize_animate_list_key(_tuple_str_to_int(pt_tuple_key)) + return flax_key, tensor + + flax_key, flax_tensor = get_key_and_value( + pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers + ) + flax_key = _normalize_animate_list_key(flax_key) return flax_key, flax_tensor @@ -248,45 +332,15 @@ def load_base_wan_transformer( tensors[k] = torch2jax(f.get_tensor(k)) flax_state_dict = {} cpu = jax.local_devices(backend="cpu")[0] - flattened_dict = flatten_dict(eval_shapes) # turn all block numbers to strings just for matching weights. 
# Later they will be turned back to ints. - random_flax_state_dict = {} - for key in flattened_dict: - string_tuple = tuple([str(item) for item in key]) - random_flax_state_dict[string_tuple] = flattened_dict[key] - del flattened_dict + random_flax_state_dict = _build_random_flax_state_dict(eval_shapes) for pt_key, tensor in tensors.items(): # The diffusers implementation explicitly describes this key in keys to be ignored. if "norm_added_q" in pt_key: continue renamed_pt_key = rename_key(pt_key) - - if "condition_embedder" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "time_embedder.linear_1") - renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "time_embedder.linear_2") - renamed_pt_key = renamed_pt_key.replace("time_projection_1", "time_proj") - renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "text_embedder.linear_1") - renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "text_embedder.linear_2") - - if "image_embedder" in renamed_pt_key: - if "net.0.proj" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0") - elif "net_0.proj" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("net_0.proj", "net_0") - if "net.2" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("net.2", "net_2") - renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm") - if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key: - renamed_pt_key = renamed_pt_key.replace("weight", "scale") - renamed_pt_key = renamed_pt_key.replace("kernel", "scale") - - renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.") - renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table") - renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn") - renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out") - renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn") - renamed_pt_key = renamed_pt_key.replace("norm2", 
"norm2.layer_norm") + renamed_pt_key = _rename_common_wan_transformer_key(renamed_pt_key) pt_tuple_key = tuple(renamed_pt_key.split(".")) flax_key, flax_tensor = get_key_and_value( pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers @@ -300,6 +354,101 @@ def load_base_wan_transformer( return flax_state_dict +def _is_motion_encoder_custom_weight(pt_key: str) -> bool: + """Returns True for FlaxMotionConv2d/FlaxMotionLinear weight keys that must NOT be renamed to kernel.""" + prefixes = ( + "motion_encoder.conv_in.", + "motion_encoder.conv_out.", + ) + if any(pt_key.startswith(p) for p in prefixes) and pt_key.endswith(".weight"): + return True + if "motion_encoder.res_blocks." in pt_key and pt_key.endswith(".weight"): + return True + if "motion_encoder.motion_network." in pt_key and pt_key.endswith(".weight"): + return True + return False + + +def load_wan_animate_transformer( + pretrained_model_name_or_path: str, + eval_shapes: dict, + device: str, + hf_download: bool = True, + num_layers: int = 40, + scan_layers: bool = True, + subfolder: str = "transformer", +): + """Loads WanAnimate transformer weights from a HuggingFace checkpoint. 
+ + Handles the additional key mappings for: + - pose_patch_embedding (nnx.Conv3d → kernel) + - motion_encoder.* (FlaxMotionConv2d/FlaxMotionLinear → keep as 'weight', no transpose) + - activation.bias → act_fn.bias (FusedLeakyReLU bias remapping) + - face_encoder.* (nnx.Conv/Linear → standard rename to kernel) + - face_adapter.* (nnx.Linear → standard rename to kernel) + """ + device = jax.local_devices(backend=device)[0] + filename = "diffusion_pytorch_model.safetensors.index.json" + local_files = False + if os.path.isdir(pretrained_model_name_or_path): + index_file_path = os.path.join(pretrained_model_name_or_path, subfolder, filename) + if not os.path.isfile(index_file_path): + raise FileNotFoundError(f"File {index_file_path} not found for local directory.") + local_files = True + elif hf_download: + index_file_path = hf_hub_download( + pretrained_model_name_or_path, + subfolder=subfolder, + filename=filename, + ) + with jax.default_device(device): + with open(index_file_path, "r") as f: + index_dict = json.load(f) + model_files = set() + for key in index_dict["weight_map"].keys(): + model_files.add(index_dict["weight_map"][key]) + + model_files = list(model_files) + tensors = {} + for model_file in model_files: + if local_files: + ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file) + else: + ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file) + max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}") + if ckpt_shard_path is not None: + with safe_open(ckpt_shard_path, framework="pt") as f: + for k in f.keys(): + tensors[k] = torch2jax(f.get_tensor(k)) + + flax_state_dict = {} + cpu = jax.local_devices(backend="cpu")[0] + random_flax_state_dict = _build_random_flax_state_dict(eval_shapes) + + for pt_key, tensor in tensors.items(): + if "norm_added_q" in pt_key: + continue + + pt_tuple_key, is_motion_custom_weight = 
_rename_wan_animate_pt_tuple_key(pt_key) + flax_key, flax_tensor = get_wan_animate_key_and_value( + pt_tuple_key, + tensor, + flax_state_dict, + random_flax_state_dict, + scan_layers, + is_motion_custom_weight=is_motion_custom_weight, + num_layers=num_layers, + ) + + flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) + + validate_flax_state_dict(eval_shapes, flax_state_dict) + flax_state_dict = unflatten_dict(flax_state_dict) + del tensors + jax.clear_caches() + return flax_state_dict + + def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): device = jax.devices(device)[0] subfolder = "vae" diff --git a/src/maxdiffusion/pipelines/wan/__init__.py b/src/maxdiffusion/pipelines/wan/__init__.py index 9a17b1e7..c6d00bf4 100644 --- a/src/maxdiffusion/pipelines/wan/__init__.py +++ b/src/maxdiffusion/pipelines/wan/__init__.py @@ -15,3 +15,4 @@ """ from .wan_pipeline import WanPipeline +from .wan_pipeline_animate import WanAnimatePipeline diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 86c9f9c2..cdcd8f22 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -605,10 +605,14 @@ def _create_common_components(cls, config, vae_only=False, i2v=False): components["tokenizer"] = cls.load_tokenizer(config=config) components["text_encoder"] = cls.load_text_encoder(config=config) components["scheduler"], components["scheduler_state"] = cls.load_scheduler(config=config) - if i2v and config.model_name == "wan2.1": + if cls._needs_image_encoder(config, i2v=i2v): components["image_processor"], components["image_encoder"] = cls.load_image_encoder(config) return components + @classmethod + def _needs_image_encoder(cls, config: HyperParameters, i2v: bool = False) -> bool: + return i2v and config.model_name == "wan2.1" + @abstractmethod def _get_num_channel_latents(self) -> int: 
"""Returns the number of input channels for the transformer.""" diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py new file mode 100644 index 00000000..4e08fba1 --- /dev/null +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py @@ -0,0 +1,1037 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""JAX/Flax pipeline for character animation using Wan-Animate. + +Ported from: + https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan_animate.py + +The pipeline supports two modes: + - "animate": Generate a video of the reference character mimicking motion from + pose/face videos. + - "replace": Replace a character in a background video with the reference + character, using pose/face videos for motion control. + +Inference runs in segments of `segment_frame_length` frames (default 77), which +are stitched together with overlap conditioning from the previous segment. 
+""" + +from copy import deepcopy +from functools import partial +from typing import List, Optional, Tuple, Union + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +import PIL +import torch +from flax import nnx +from flax.linen import partitioning as nn_partitioning +from jax.sharding import NamedSharding, PartitionSpec as P +from maxdiffusion import max_logging +from maxdiffusion.image_processor import PipelineImageInput, VaeImageProcessor +from maxdiffusion.max_utils import device_put_replicated, get_flash_block_sizes, get_precision +from maxdiffusion.video_processor import VideoProcessor + +from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan +from ...models.wan.transformers.transformer_wan_animate import NNXWanAnimateTransformer3DModel +from ...models.wan.wan_utils import load_wan_animate_transformer +from ...pyconfig import HyperParameters +from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler +from .wan_pipeline import WanPipeline, cast_with_exclusion + + +def create_sharded_animate_transformer( + devices_array: np.ndarray, + mesh, + rngs: nnx.Rngs, + config: HyperParameters, + restored_checkpoint=None, + subfolder: str = "transformer", +) -> NNXWanAnimateTransformer3DModel: + """Creates a sharded NNXWanAnimateTransformer3DModel on device. + + Follows the same pattern as create_sharded_logical_transformer in + wan_pipeline.py but uses NNXWanAnimateTransformer3DModel and the + animate-specific weight loader. + """ + + def _create_model(rngs: nnx.Rngs, wan_config: dict): + return NNXWanAnimateTransformer3DModel(**wan_config, rngs=rngs) + + # 1. Load config. 
+ if restored_checkpoint: + wan_config = restored_checkpoint["wan_config"] + else: + wan_config = NNXWanAnimateTransformer3DModel.load_config( + config.pretrained_model_name_or_path, subfolder=subfolder + ) + + wan_config["mesh"] = mesh + wan_config["dtype"] = config.activations_dtype + wan_config["weights_dtype"] = config.weights_dtype + wan_config["attention"] = config.attention + wan_config["precision"] = get_precision(config) + wan_config["flash_block_sizes"] = get_flash_block_sizes(config) + wan_config["remat_policy"] = config.remat_policy + wan_config["names_which_can_be_saved"] = config.names_which_can_be_saved + wan_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded + wan_config["flash_min_seq_length"] = config.flash_min_seq_length + wan_config["dropout"] = config.dropout + wan_config["mask_padding_tokens"] = config.mask_padding_tokens + wan_config["scan_layers"] = config.scan_layers + wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes + + # 2. eval_shape – creates the model structure without allocating HBM. + p_model_factory = partial(_create_model, wan_config=wan_config) + wan_transformer = nnx.eval_shape(p_model_factory, rngs=rngs) + graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...) + + # 3. Retrieve logical-to-mesh sharding mappings. + logical_state_spec = nnx.get_partition_spec(state) + logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) + logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding)) + params = state.to_pure_dict() + state = dict(nnx.to_flat_state(state)) + + # 4. Load and shard pretrained weights. 
+ if restored_checkpoint: + if "params" in restored_checkpoint["wan_state"]: + params = restored_checkpoint["wan_state"]["params"] + else: + params = restored_checkpoint["wan_state"] + else: + params = load_wan_animate_transformer( + config.wan_transformer_pretrained_model_name_or_path, + params, + "cpu", + num_layers=wan_config["num_layers"], + scan_layers=config.scan_layers, + subfolder=subfolder, + ) + + params = jax.tree_util.tree_map_with_path( + lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=config.weights_dtype), params + ) + for path, val in flax.traverse_util.flatten_dict(params).items(): + if restored_checkpoint: + path = path[:-1] + sharding = logical_state_sharding[path].value + state[path].value = device_put_replicated(val, sharding) + state = nnx.from_flat_state(state) + + wan_transformer = nnx.merge(graphdef, state, rest_of_state) + return wan_transformer + + +# --------------------------------------------------------------------------- +# JIT-compiled transformer forward pass +# --------------------------------------------------------------------------- + + +@partial(jax.jit, static_argnames=("motion_encode_batch_size",)) +def animate_transformer_forward_pass( + graphdef, + sharded_state, + rest_of_state, + latents: jnp.ndarray, + reference_latents: jnp.ndarray, + pose_latents: jnp.ndarray, + face_video_segment: jnp.ndarray, + timestep: jnp.ndarray, + encoder_hidden_states: jnp.ndarray, + encoder_hidden_states_image: jnp.ndarray, + motion_encode_batch_size: Optional[int] = None, +) -> jnp.ndarray: + """Single denoising step for WanAnimate. + + Args: + latents: Noisy latents, shape (B, T_lat+1, H_lat, W_lat, z_dim), channel-last. + reference_latents: Reference image + prev-seg conditioning, + shape (B, T_lat+1, H_lat, W_lat, z_dim+4), channel-last. + pose_latents: VAE-encoded pose video, shape (B, T_lat, H_lat, W_lat, z_dim), + channel-last. 
class WanAnimatePipeline(WanPipeline):
  """JAX/Flax pipeline for Wan-Animate character animation.

  Supports two modes:
    - "animate": Animate the reference character using pose and face videos.
    - "replace": Replace a character in a background video using a mask.

  Inference is performed in temporal segments to handle arbitrary video lengths.
  Each segment denoises `segment_frame_length` frames, with overlap conditioning
  from the last few frames of the previous segment.

  Args:
    config: HyperParameters configuration.
    transformer: NNXWanAnimateTransformer3DModel instance (may be None for
      VAE-only mode).
    **kwargs: Passed to WanPipeline.__init__ (tokenizer, text_encoder, vae, etc.)
  """

  def __init__(
      self,
      config: HyperParameters,
      transformer: Optional[NNXWanAnimateTransformer3DModel],
      **kwargs,
  ):
    super().__init__(config=config, **kwargs)
    self.transformer = transformer
    # In VAE-only mode there is no transformer; fall back to the default (2, 2).
    spatial_patch_size = self.transformer.config.patch_size[-2:] if self.transformer is not None else (2, 2)
    # Processor for the reference character image; "fill" pads to target size.
    self.ref_image_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor_spatial,
        spatial_patch_size=spatial_patch_size,
        resample="bilinear",
        resize_mode="fill",
        fill_color=0,
    )
    # Grayscale, non-normalized processor for replace-mode mask videos.
    self.video_processor_for_mask = VideoProcessor(
        vae_scale_factor=self.vae_scale_factor_spatial,
        do_normalize=False,
        do_convert_grayscale=True,
    )

  @classmethod
  def _needs_image_encoder(cls, config: HyperParameters, i2v: bool = False) -> bool:
    # Wan-Animate always conditions on CLIP image embeddings, so the base
    # class's _create_common_components always loads the image encoder.
    return True

  @classmethod
  def load_animate_transformer(
      cls,
      devices_array: np.ndarray,
      mesh,
      rngs: nnx.Rngs,
      config: HyperParameters,
      restored_checkpoint=None,
      subfolder: str = "transformer",
  ) -> NNXWanAnimateTransformer3DModel:
    """Create the sharded animate transformer inside the mesh context."""
    with mesh:
      return create_sharded_animate_transformer(
          devices_array=devices_array,
          mesh=mesh,
          rngs=rngs,
          config=config,
          restored_checkpoint=restored_checkpoint,
          subfolder=subfolder,
      )

  @classmethod
  def _load_and_init(
      cls,
      config: HyperParameters,
      restored_checkpoint=None,
      vae_only: bool = False,
      load_transformer: bool = True,
  ) -> Tuple["WanAnimatePipeline", Optional[NNXWanAnimateTransformer3DModel]]:
    """Build shared components and assemble the pipeline (and transformer).

    Returns:
      (pipeline, transformer); transformer is None when vae_only or
      load_transformer is False.
    """
    common_components = cls._create_common_components(config, vae_only)
    transformer = None
    if not vae_only and load_transformer:
      transformer = cls.load_animate_transformer(
          devices_array=common_components["devices_array"],
          mesh=common_components["mesh"],
          rngs=common_components["rngs"],
          config=config,
          restored_checkpoint=restored_checkpoint,
          subfolder="transformer",
      )
    pipeline = cls(
        config=config,
        transformer=transformer,
        tokenizer=common_components["tokenizer"],
        text_encoder=common_components["text_encoder"],
        image_processor=common_components["image_processor"],
        image_encoder=common_components["image_encoder"],
        vae=common_components["vae"],
        vae_cache=common_components["vae_cache"],
        scheduler=common_components["scheduler"],
        scheduler_state=common_components["scheduler_state"],
        devices_array=common_components["devices_array"],
        mesh=common_components["mesh"],
    )
    return pipeline, transformer

  @classmethod
  def from_pretrained(
      cls,
      config: HyperParameters,
      vae_only: bool = False,
      load_transformer: bool = True,
  ) -> "WanAnimatePipeline":
    """Load from the HF checkpoint, then (optionally) quantize the transformer."""
    pipeline, transformer = cls._load_and_init(config, None, vae_only, load_transformer)
    pipeline.transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh)
    return pipeline

  @classmethod
  def from_checkpoint(
      cls,
      config: HyperParameters,
      restored_checkpoint=None,
      vae_only: bool = False,
      load_transformer: bool = True,
  ) -> "WanAnimatePipeline":
    """Load from a previously saved (orbax-style) checkpoint dict."""
    pipeline, _ = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer)
    return pipeline

  # ------------------------------------------------------------------
  # Abstract method implementation
  # ------------------------------------------------------------------

  def _get_num_channel_latents(self) -> int:
    # Latent channel count equals the VAE's latent (z) dimensionality.
    return self.vae.z_dim
  def check_inputs(
      self,
      prompt,
      negative_prompt,
      image,
      pose_video,
      face_video,
      background_video,
      mask_video,
      height,
      width,
      prompt_embeds=None,
      negative_prompt_embeds=None,
      image_embeds=None,
      mode=None,
      prev_segment_conditioning_frames=None,
  ):
    """Validate user-facing pipeline inputs with Diffusers-compatible checks.

    Raises ValueError on the first violated constraint; check order matches the
    upstream Diffusers pipeline so error behavior stays compatible.
    """
    supported_image_types = (torch.Tensor, PIL.Image.Image, np.ndarray, jnp.ndarray)

    # Reference image: exactly one of `image` / `image_embeds`, of a known type.
    if image is not None and image_embeds is not None:
      raise ValueError(
          f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
          " only forward one of the two."
      )
    if image is None and image_embeds is None:
      raise ValueError("Provide either `image` or `image_embeds`. Cannot leave both undefined.")
    if image is not None and not isinstance(image, supported_image_types):
      raise ValueError(
          f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, or `jnp.ndarray` but is {type(image)}"
      )
    # Motion control videos are mandatory non-empty lists of frames.
    if pose_video is None:
      raise ValueError("Provide `pose_video`. Cannot leave `pose_video` undefined.")
    if face_video is None:
      raise ValueError("Provide `face_video`. Cannot leave `face_video` undefined.")
    if not isinstance(pose_video, list) or not isinstance(face_video, list):
      raise ValueError("`pose_video` and `face_video` must be lists of PIL images.")
    if len(pose_video) == 0 or len(face_video) == 0:
      raise ValueError("`pose_video` and `face_video` must contain at least one frame.")
    # Replace mode additionally needs a background video and its mask.
    if mode == "replace" and (background_video is None or mask_video is None):
      raise ValueError(
          "Provide `background_video` and `mask_video`. Cannot leave both `background_video` and `mask_video`"
          " undefined when mode is `replace`."
      )
    if mode == "replace" and (not isinstance(background_video, list) or not isinstance(mask_video, list)):
      raise ValueError("`background_video` and `mask_video` must be lists of PIL images when mode is `replace`.")
    # Spatial dims must align with the VAE/patchify stride.
    if height % 16 != 0 or width % 16 != 0:
      raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
    # Prompts: exactly one of raw text / precomputed embeds, of valid type.
    if prompt is not None and prompt_embeds is not None:
      raise ValueError(
          f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
          " only forward one of the two."
      )
    if negative_prompt is not None and negative_prompt_embeds is not None:
      raise ValueError(
          f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
          " only forward one of the two."
      )
    if prompt is None and prompt_embeds is None:
      raise ValueError("Provide either `prompt` or `prompt_embeds`. Cannot leave both undefined.")
    if prompt is not None and not isinstance(prompt, (str, list)):
      raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    if negative_prompt is not None and not isinstance(negative_prompt, (str, list)):
      raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
    # Mode and overlap-frame count take a small closed set of values.
    if mode is not None and (not isinstance(mode, str) or mode not in ("animate", "replace")):
      raise ValueError(
          f"`mode` has to be of type `str` and in ('animate', 'replace') but its type is {type(mode)} and value is {mode}"
      )
    if prev_segment_conditioning_frames is not None and (
        not isinstance(prev_segment_conditioning_frames, int) or prev_segment_conditioning_frames not in (1, 5)
    ):
      raise ValueError(
          f"`prev_segment_conditioning_frames` has to be of type `int` and 1 or 5 but its type is"
          f" {type(prev_segment_conditioning_frames)} and value is {prev_segment_conditioning_frames}"
      )
+ ) + if mode == "replace" and (not isinstance(background_video, list) or not isinstance(mask_video, list)): + raise ValueError("`background_video` and `mask_video` must be lists of PIL images when mode is `replace`.") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + if prompt is None and prompt_embeds is None: + raise ValueError("Provide either `prompt` or `prompt_embeds`. Cannot leave both undefined.") + if prompt is not None and not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if negative_prompt is not None and not isinstance(negative_prompt, (str, list)): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + if mode is not None and (not isinstance(mode, str) or mode not in ("animate", "replace")): + raise ValueError( + f"`mode` has to be of type `str` and in ('animate', 'replace') but its type is {type(mode)} and value is {mode}" + ) + if prev_segment_conditioning_frames is not None and ( + not isinstance(prev_segment_conditioning_frames, int) or prev_segment_conditioning_frames not in (1, 5) + ): + raise ValueError( + f"`prev_segment_conditioning_frames` has to be of type `int` and 1 or 5 but its type is" + f" {type(prev_segment_conditioning_frames)} and value is {prev_segment_conditioning_frames}" + ) + + @staticmethod + def pad_video_frames(frames: list, 
  # ------------------------------------------------------------------
  # I2V mask helpers
  # ------------------------------------------------------------------

  def get_i2v_mask(
      self,
      batch_size: int,
      latent_t: int,
      latent_h: int,
      latent_w: int,
      mask_len: int = 1,
      mask_pixel_values: Optional[jnp.ndarray] = None,
      dtype: jnp.dtype = jnp.float32,
  ) -> jnp.ndarray:
    """Construct an I2V conditioning mask in channel-last format.

    A mask value of 1 means "this frame is known/conditioned" and 0 means
    "this frame is freely generated".

    Args:
      batch_size: Leading batch dimension of the returned mask.
      latent_t: Number of latent temporal frames.
      latent_h: Latent height.
      latent_w: Latent width.
      mask_len: Number of leading frames to force to 1 (known). 0 is a no-op.
      mask_pixel_values: Optional pre-computed mask at pixel temporal resolution
        but latent spatial resolution, shape (B, 1, T_pixel, H_lat, W_lat).
        T_pixel = (latent_t - 1) * vae_scale_factor_temporal + 1.
      dtype: Output dtype.

    Returns:
      Mask array of shape (B, latent_t, H_lat, W_lat, vae_scale_factor_temporal).
    """
    vae_scale = self.vae_scale_factor_temporal
    pixel_frames = (latent_t - 1) * vae_scale + 1

    if mask_pixel_values is None:
      # Default: nothing conditioned except the frames forced by mask_len below.
      mask_lat_size = jnp.zeros((batch_size, 1, pixel_frames, latent_h, latent_w), dtype=dtype)
    else:
      mask_lat_size = mask_pixel_values.astype(dtype)

    # Set the first mask_len pixel frames to 1 (conditioned).
    mask_lat_size = mask_lat_size.at[:, :, :mask_len, :, :].set(1.0)

    # Repeat the first frame vae_scale times so total frames = latent_t * vae_scale.
    first_frame = mask_lat_size[:, :, 0:1, :, :]  # (B, 1, 1, H, W)
    first_frame = jnp.repeat(first_frame, vae_scale, axis=2)  # (B, 1, vae_scale, H, W)
    mask_lat_size = jnp.concatenate([first_frame, mask_lat_size[:, :, 1:, :, :]], axis=2)
    # (B, 1, latent_t*vae_scale, H, W)

    # Reshape: (B, 1, latent_t*vae_scale, H, W) → (B, latent_t, vae_scale, H, W)
    # i.e. group each latent frame's vae_scale pixel frames together.
    mask_lat_size = mask_lat_size.reshape(batch_size, latent_t, vae_scale, latent_h, latent_w)
    # Transpose to (B, vae_scale, latent_t, H, W) then to (B, latent_t, H, W, vae_scale).
    mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 1, 3, 4))  # (B, vae_scale, T, H, W)
    mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 3, 4, 1))  # (B, T, H, W, vae_scale)
    return mask_lat_size

  # ------------------------------------------------------------------
  # Latent preparation helpers
  # ------------------------------------------------------------------

  def _encode_video_to_latents(
      self,
      video: jnp.ndarray,
      dtype: jnp.dtype,
  ) -> jnp.ndarray:
    """Encode a video tensor with the VAE and normalize the latents.

    Args:
      video: (B, C, T, H, W) channel-first, values in [-1, 1].
      dtype: dtype of the returned latents.

    Returns:
      Normalized latents: (B, T_lat, H_lat, W_lat, z_dim) channel-last.
    """
    vae_dtype = getattr(self.vae, "dtype", jnp.float32)
    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
      # .mode() takes the deterministic posterior mode (no sampling noise).
      encoded = self.vae.encode(video.astype(vae_dtype), self.vae_cache)[0].mode()
      # Normalize with the per-channel latent statistics from the VAE.
      mean = jnp.array(self.vae.latents_mean).reshape(1, 1, 1, 1, self.vae.z_dim)
      std = jnp.array(self.vae.latents_std).reshape(1, 1, 1, 1, self.vae.z_dim)
      latents = (encoded - mean) / std
    return latents.astype(dtype)
+ + Returns: + (B, 1, H_lat, W_lat, z_dim + vae_scale_factor_temporal) channel-last. + """ + if image.ndim == 4: + image = image[:, :, jnp.newaxis, :, :] # (B, C, 1, H, W) + + # Encode the single reference frame. + ref_latents = self._encode_video_to_latents(image, dtype) # (B, 1, H_lat, W_lat, z_dim) + + if ref_latents.shape[0] == 1 and batch_size > 1: + ref_latents = jnp.broadcast_to(ref_latents, (batch_size,) + ref_latents.shape[1:]) + + latent_h = ref_latents.shape[2] + latent_w = ref_latents.shape[3] + + # Mask for the single reference frame — mark it as fully conditioned. + ref_mask = self.get_i2v_mask(batch_size, 1, latent_h, latent_w, mask_len=1, dtype=dtype) + # (B, 1, H_lat, W_lat, vae_scale) + + return jnp.concatenate([ref_mask, ref_latents], axis=-1) + + def _resize_mask_to_latent_spatial( + self, + mask: jnp.ndarray, + latent_h: int, + latent_w: int, + ) -> jnp.ndarray: + """Resize a mask from pixel spatial resolution to latent spatial resolution. + + Args: + mask: (B, 1, T, H, W) channel-first. + + Returns: + (B, 1, T, H_lat, W_lat) channel-first. + """ + B, C, T, H, W = mask.shape + if H == latent_h and W == latent_w: + return mask + # Match torch.nn.functional.interpolate(..., mode="nearest") exactly. + h_indices = jnp.floor(jnp.arange(latent_h) * (H / latent_h)).astype(jnp.int32) + w_indices = jnp.floor(jnp.arange(latent_w) * (W / latent_w)).astype(jnp.int32) + mask = jnp.take(mask, h_indices, axis=3) + return jnp.take(mask, w_indices, axis=4) + + def prepare_prev_segment_cond_latents( + self, + prev_segment_cond_video: Optional[jnp.ndarray], + background_video: Optional[jnp.ndarray], + mask_video: Optional[jnp.ndarray], + batch_size: int, + segment_frame_length: int, + start_frame: int, + height: int, + width: int, + prev_segment_cond_frames: int, + task: str, + dtype: jnp.dtype, + ) -> jnp.ndarray: + """Prepare latent conditioning from the previous segment. 
+ + Args: + prev_segment_cond_video: Last N decoded frames from the previous segment, + shape (B, C, N, H, W) channel-first in [-1, 1], or None for segment 0. + background_video: Background video segment for replace mode, + shape (B, C, T_seg, H, W). + mask_video: Mask video segment for replace mode (white=generate, black=preserve), + shape (B, 1, T_seg, H, W). + start_frame: Pixel-space start frame of the current segment. + task: "animate" or "replace". + + Returns: + (B, T_lat, H_lat, W_lat, z_dim + vae_scale_factor_temporal) channel-last. + """ + vae_dtype = getattr(self.vae, "dtype", jnp.float32) + latent_h = height // self.vae_scale_factor_spatial + latent_w = width // self.vae_scale_factor_spatial + num_latent_frames = (segment_frame_length - 1) // self.vae_scale_factor_temporal + 1 + + if prev_segment_cond_video is None: + if task == "replace" and background_video is not None: + prev_segment_cond_video = background_video[:, :, :prev_segment_cond_frames] + else: + prev_segment_cond_video = jnp.zeros( + (batch_size, 3, prev_segment_cond_frames, height, width), dtype=vae_dtype + ) + + # Build full-length cond video (prev frames + remainder). + if task == "replace" and background_video is not None: + remaining = background_video[:, :, prev_segment_cond_frames:] + else: + remaining_frames = segment_frame_length - prev_segment_cond_frames + remaining = jnp.zeros((batch_size, 3, remaining_frames, height, width), dtype=vae_dtype) + + full_cond_video = jnp.concatenate( + [prev_segment_cond_video.astype(vae_dtype), remaining], axis=2 + ) # (B, C, T_seg, H, W) + + cond_latents = self._encode_video_to_latents(full_cond_video, dtype) + # (B, T_lat, H_lat, W_lat, z_dim) + + # Build I2V mask. + if task == "replace" and mask_video is not None: + # Invert mask: white (1.0, generate) → 0.0, black (0.0, preserve) → 1.0. + # In the I2V mask convention, 1 = known/conditioned, 0 = freely generated. 
+ inverted_mask = 1.0 - mask_video + mask_pixel_values = self._resize_mask_to_latent_spatial(inverted_mask, latent_h, latent_w) + # mask_pixel_values: (B, 1, T_seg, H_lat, W_lat) – pixel temporal resolution + else: + mask_pixel_values = None + + cond_mask = self.get_i2v_mask( + batch_size, + num_latent_frames, + latent_h, + latent_w, + mask_len=prev_segment_cond_frames if start_frame > 0 else 0, + mask_pixel_values=mask_pixel_values, + dtype=dtype, + ) + # (B, T_lat, H_lat, W_lat, vae_scale) + + return jnp.concatenate([cond_mask, cond_latents], axis=-1) + + def prepare_pose_latents( + self, + pose_video: jnp.ndarray, + batch_size: int, + dtype: jnp.dtype, + ) -> jnp.ndarray: + """Encode the pose video segment to latents. + + Args: + pose_video: (B, C, T_seg, H, W) channel-first, values in [-1, 1]. + + Returns: + (B, T_lat, H_lat, W_lat, z_dim) channel-last. + """ + pose_latents = self._encode_video_to_latents(pose_video, dtype) + if pose_latents.shape[0] == 1 and batch_size > 1: + pose_latents = jnp.broadcast_to(pose_latents, (batch_size,) + pose_latents.shape[1:]) + return pose_latents + + def prepare_segment_latents( + self, + batch_size: int, + height: int, + width: int, + segment_frame_length: int, + dtype: jnp.dtype, + rng: jax.Array, + latents: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: + """Sample noisy latents for a denoising segment. + + The +1 accounts for the reference frame slot at index 0. + + Returns: + (B, T_lat+1, H_lat, W_lat, z_dim) channel-last. 
+ """ + num_latent_frames = (segment_frame_length - 1) // self.vae_scale_factor_temporal + 1 + latent_h = height // self.vae_scale_factor_spatial + latent_w = width // self.vae_scale_factor_spatial + shape = (batch_size, num_latent_frames + 1, latent_h, latent_w, self.vae.z_dim) + if latents is not None: + latents = jnp.asarray(latents) + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape {latents.shape}; expected {shape}.") + return latents.astype(dtype) + return jax.random.normal(rng, shape=shape, dtype=jnp.float32).astype(dtype) + + def _decode_segment_to_pixels(self, latents_cl: jnp.ndarray) -> jnp.ndarray: + """Decode latents and return raw pixel-space frames for re-encoding. + + Args: + latents_cl: (B, T_lat, H_lat, W_lat, z_dim) channel-last, normalised. + + Returns: + (B, C, T, H, W) channel-first, values in [-1, 1] (VAE output range). + """ + latents_cf = jnp.transpose(latents_cl, (0, 4, 1, 2, 3)) # (B, z_dim, T, H, W) + latents_cf = self._denormalize_latents(latents_cf) + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + video_cl = self.vae.decode(latents_cf, self.vae_cache)[0] # (B, T, H, W, C) + return jnp.transpose(video_cl, (0, 4, 1, 2, 3)) # (B, C, T, H, W) + + # ------------------------------------------------------------------ + # Main inference + # ------------------------------------------------------------------ + + def __call__( + self, + image: PipelineImageInput, + pose_video: List[PIL.Image.Image], + face_video: List[PIL.Image.Image], + background_video: Optional[List[PIL.Image.Image]] = None, + mask_video: Optional[List[PIL.Image.Image]] = None, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + segment_frame_length: int = 77, + num_inference_steps: int = 20, + mode: str = "animate", + prev_segment_conditioning_frames: int = 1, + motion_encode_batch_size: Optional[int] = 
  def __call__(
      self,
      image: PipelineImageInput,
      pose_video: List[PIL.Image.Image],
      face_video: List[PIL.Image.Image],
      background_video: Optional[List[PIL.Image.Image]] = None,
      mask_video: Optional[List[PIL.Image.Image]] = None,
      prompt: Optional[Union[str, List[str]]] = None,
      negative_prompt: Optional[Union[str, List[str]]] = None,
      height: Optional[int] = None,
      width: Optional[int] = None,
      segment_frame_length: int = 77,
      num_inference_steps: int = 20,
      mode: str = "animate",
      prev_segment_conditioning_frames: int = 1,
      motion_encode_batch_size: Optional[int] = None,
      guidance_scale: float = 1.0,
      num_videos_per_prompt: int = 1,
      max_sequence_length: int = 512,
      latents: Optional[jnp.ndarray] = None,
      prompt_embeds: Optional[jnp.ndarray] = None,
      negative_prompt_embeds: Optional[jnp.ndarray] = None,
      image_embeds: Optional[jnp.ndarray] = None,
      output_type: Optional[str] = "np",
      rng: Optional[jax.Array] = None,
  ):
    """Run the Wan-Animate inference pipeline.

    Args:
      image: Reference character image (PIL.Image or compatible).
      pose_video: List of PIL frames representing the pose video.
      face_video: List of PIL frames representing the face video.
      background_video: (replace mode) Background video frames.
      mask_video: (replace mode) Mask frames. White=generate, black=preserve.
      prompt: Text prompt(s).
      negative_prompt: Negative prompt(s) for CFG (only used when guidance_scale > 1).
      height: Output video height in pixels.
      width: Output video width in pixels.
      segment_frame_length: Number of frames per denoising segment. Should satisfy
        (segment_frame_length - 1) % vae_scale_factor_temporal == 0.
      num_inference_steps: Denoising steps per segment.
      mode: "animate" or "replace".
      prev_segment_conditioning_frames: Overlap frames between segments (1 or 5).
      motion_encode_batch_size: Batch size for the motion encoder. Defaults to
        the transformer's configured value.
      guidance_scale: CFG scale (set > 1 to enable classifier-free guidance).
      num_videos_per_prompt: Number of videos to generate per prompt.
      rng: Optional JAX PRNG key.

    Returns:
      If output_type == "np": numpy array of shape (B, T, H, W, C) in [0, 1].
      If output_type == "latent": raw latents from the final segment.
    """
    height = height or self.config.height
    width = width or self.config.width

    self.check_inputs(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=image,
        pose_video=pose_video,
        face_video=face_video,
        background_video=background_video,
        mask_video=mask_video,
        height=height,
        width=width,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        image_embeds=image_embeds,
        mode=mode,
        prev_segment_conditioning_frames=prev_segment_conditioning_frames,
    )

    # Ensure segment_frame_length satisfies the VAE temporal constraint.
    if segment_frame_length % self.vae_scale_factor_temporal != 1:
      max_logging.log(
          f"`segment_frame_length - 1` must be divisible by {self.vae_scale_factor_temporal}. "
          f"Rounding {segment_frame_length}."
      )
      segment_frame_length = (
          segment_frame_length // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
      )
      segment_frame_length = max(segment_frame_length, 1)

    do_classifier_free_guidance = guidance_scale > 1.0

    # Determine batch size.
    if prompt is not None and isinstance(prompt, str):
      batch_size = 1
    elif prompt is not None:
      batch_size = len(prompt)
    else:
      # check_inputs guarantees prompt_embeds is present when prompt is None.
      batch_size = prompt_embeds.shape[0]
    effective_batch_size = batch_size * num_videos_per_prompt

    # Segment arithmetic: pad the conditioning video so that, after accounting
    # for the overlap frames reused from the previous segment, the total frame
    # count divides evenly into whole segments.
    cond_video_frames = len(pose_video)
    effective_segment_length = segment_frame_length - prev_segment_conditioning_frames
    last_frames = (cond_video_frames - prev_segment_conditioning_frames) % effective_segment_length
    num_padding_frames = 0 if last_frames == 0 else effective_segment_length - last_frames
    num_target_frames = cond_video_frames + num_padding_frames
    num_segments = num_target_frames // effective_segment_length

    # ---- 1. Encode prompts ----
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )
    transformer_dtype = self.config.activations_dtype
    latent_dtype = jnp.float32
    prompt_embeds = prompt_embeds.astype(transformer_dtype)
    if negative_prompt_embeds is not None:
      negative_prompt_embeds = negative_prompt_embeds.astype(transformer_dtype)

    # ---- 2. Encode reference image with CLIP ----
    if image_embeds is None:
      image_embeds = self.encode_image(image, num_videos_per_prompt=effective_batch_size)
    image_embeds = image_embeds.astype(transformer_dtype)

    # ---- 3. VAE-encode reference image ----
    # Use VaeImageProcessor with resize_mode="fill" so the character is letterboxed instead of cropped.
    image_tensor = self.ref_image_processor.preprocess(image, height=height, width=width)
    image_tensor = jnp.array(image_tensor.cpu().numpy())
    if image_tensor.ndim == 3:
      image_tensor = image_tensor[None]  # (1, C, H, W)
    if effective_batch_size > 1 and image_tensor.shape[0] == 1:
      image_tensor = jnp.broadcast_to(image_tensor, (effective_batch_size,) + image_tensor.shape[1:])

    reference_image_latents = self.prepare_reference_image_latents(
        image_tensor, effective_batch_size, transformer_dtype
    )  # (B, 1, H_lat, W_lat, z_dim+vae_scale)

    # ---- 4. Preprocess conditioning videos ----
    pose_video = self.pad_video_frames(pose_video, num_target_frames)
    face_video = self.pad_video_frames(face_video, num_target_frames)

    pose_video_tensor = self.video_processor.preprocess_video(pose_video, height=height, width=width)
    pose_video_tensor = jnp.array(pose_video_tensor.cpu().numpy())  # (1, C, T, H, W)

    face_size = self.transformer.motion_encoder.size
    face_video_tensor = self.video_processor.preprocess_video(
        face_video, height=face_size, width=face_size
    )
    face_video_tensor = jnp.array(face_video_tensor.cpu().numpy())  # (1, C, T, face_size, face_size)

    background_video_tensor = None
    mask_video_tensor = None
    if mode == "replace":
      if background_video is None or mask_video is None:
        raise ValueError("`background_video` and `mask_video` are required for replace mode.")
      background_video = self.pad_video_frames(background_video, num_target_frames)
      mask_video = self.pad_video_frames(mask_video, num_target_frames)

      background_video_tensor = self.video_processor.preprocess_video(
          background_video, height=height, width=width
      )
      background_video_tensor = jnp.array(background_video_tensor.cpu().numpy())
      mask_video_tensor = self.video_processor_for_mask.preprocess_video(mask_video, height=height, width=width)
      mask_video_tensor = jnp.array(mask_video_tensor.cpu().numpy())

    if rng is None:
      rng = jax.random.key(self.config.seed)

    # ---- 5. Device placement ----
    # NOTE(review): the `// ... == 0` predicate looks inverted — as written,
    # data sharding is applied only when global batch // per-device batch is 0
    # and replication is the common path. Confirm the intended condition.
    data_sharding = NamedSharding(self.mesh, P())
    if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0:
      data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))

    prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
    if negative_prompt_embeds is not None:
      negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding)
    image_embeds = jax.device_put(image_embeds, data_sharding)

    # Split transformer params from the rest so the forward pass can be jitted
    # against a stable graphdef.
    graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...)

    # ---- 6. Segment denoising loop ----
    start = 0
    end = segment_frame_length
    all_out_frames_cf = []  # list of (B, C, T, H, W) channel-first in [-1,1]
    out_frames_cf = None  # decoded output from previous segment

    for _seg in range(num_segments):
      # Fresh noise per segment; key threaded through the loop.
      rng, latents_rng = jax.random.split(rng)

      seg_latents = self.prepare_segment_latents(
          effective_batch_size,
          height,
          width,
          segment_frame_length,
          latent_dtype,
          latents_rng,
          # User-provided latents only apply to the first segment.
          latents=latents if start == 0 else None,
      )  # (B, T_lat+1, H_lat, W_lat, z_dim)

      # Extract segment slices.
      pose_seg = pose_video_tensor[:, :, start:end]  # (1, C, T_seg, H, W)
      face_seg = face_video_tensor[:, :, start:end]  # (1, C, T_seg, face_size, face_size)

      if effective_batch_size > 1:
        pose_seg = jnp.broadcast_to(pose_seg, (effective_batch_size,) + pose_seg.shape[1:])
        face_seg = jnp.broadcast_to(face_seg, (effective_batch_size,) + face_seg.shape[1:])
      face_seg = face_seg.astype(transformer_dtype)

      # Previous segment conditioning frames (pixel space, channel-first, [-1,1]).
      prev_cond_video = None
      if start > 0 and out_frames_cf is not None:
        prev_cond_video = out_frames_cf[:, :, -prev_segment_conditioning_frames:]

      # Encode pose and prepare prev-seg conditioning.
      with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
        pose_latents = self.prepare_pose_latents(pose_seg, effective_batch_size, transformer_dtype)

      bg_seg = None
      mask_seg = None
      if mode == "replace":
        bg_seg = background_video_tensor[:, :, start:end]
        mask_seg = mask_video_tensor[:, :, start:end]
        if effective_batch_size > 1:
          bg_seg = jnp.broadcast_to(bg_seg, (effective_batch_size,) + bg_seg.shape[1:])
          mask_seg = jnp.broadcast_to(mask_seg, (effective_batch_size,) + mask_seg.shape[1:])

      prev_seg_cond_latents = self.prepare_prev_segment_cond_latents(
          prev_segment_cond_video=prev_cond_video,
          background_video=bg_seg,
          mask_video=mask_seg,
          batch_size=effective_batch_size,
          segment_frame_length=segment_frame_length,
          start_frame=start,
          height=height,
          width=width,
          prev_segment_cond_frames=prev_segment_conditioning_frames,
          task=mode,
          dtype=transformer_dtype,
      )  # (B, T_lat, H_lat, W_lat, z_dim+vae_scale)

      # Combine reference (1 frame) + prev-seg conditioning (T_lat frames).
      reference_latents = jnp.concatenate(
          [reference_image_latents, prev_seg_cond_latents], axis=1
      )  # (B, T_lat+1, H_lat, W_lat, z_dim+vae_scale)

      # Set up scheduler timesteps for this segment (scheduler state is
      # re-initialised per segment, then threaded through the step loop).
      scheduler_state = self.scheduler.set_timesteps(
          self.scheduler_state, num_inference_steps=num_inference_steps, shape=seg_latents.shape
      )

      seg_latents = jax.device_put(seg_latents, data_sharding)
      reference_latents = jax.device_put(reference_latents, data_sharding)
      pose_latents = jax.device_put(pose_latents, data_sharding)
      face_seg = jax.device_put(face_seg, data_sharding)

      # Denoising loop.
      with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
        for step in range(num_inference_steps):
          t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
          timestep = jnp.broadcast_to(t, (seg_latents.shape[0],))

          noise_pred = animate_transformer_forward_pass(
              graphdef,
              state,
              rest_of_state,
              seg_latents,
              reference_latents,
              pose_latents,
              face_seg,
              timestep,
              prompt_embeds,
              image_embeds,
              motion_encode_batch_size=motion_encode_batch_size,
          )

          if do_classifier_free_guidance:
            # Blank face pixels (all -1) for the unconditional pass.
            face_seg_uncond = face_seg * 0 - 1
            noise_uncond = animate_transformer_forward_pass(
                graphdef,
                state,
                rest_of_state,
                seg_latents,
                reference_latents,
                pose_latents,
                face_seg_uncond,
                timestep,
                negative_prompt_embeds,
                image_embeds,
                motion_encode_batch_size=motion_encode_batch_size,
            )
            noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

          noise_pred = noise_pred.astype(seg_latents.dtype)
          seg_latents, scheduler_state = self.scheduler.step(
              scheduler_state, noise_pred, t, seg_latents, return_dict=False
          )

      # Decode this segment (skip reference frame at index 0).
      out_frames_cf = self._decode_segment_to_pixels(seg_latents[:, 1:, :, :, :])
      # (B, C, T_pixel, H, W) channel-first in [-1, 1]

      if start > 0:
        # Drop overlap frames used for conditioning.
        out_frames_cf_trimmed = out_frames_cf[:, :, prev_segment_conditioning_frames:]
      else:
        out_frames_cf_trimmed = out_frames_cf

      all_out_frames_cf.append(out_frames_cf_trimmed)

      start += effective_segment_length
      end += effective_segment_length

    # ---- 7. Assemble output ----
    # Concat along the temporal dimension and trim to the original video length.
    video_cf = jnp.concatenate(all_out_frames_cf, axis=2)[:, :, :cond_video_frames]
    # (B, C, T, H, W) channel-first in [-1, 1]

    if output_type == "latent":
      # Only the final segment's latents are returned (see docstring).
      return seg_latents

    # Postprocess to [0, 1] numpy.
    # NOTE(review): the bfloat16 cast quantizes pixel values before
    # postprocess_video — confirm this precision loss is intentional.
    video_torch = torch.from_numpy(np.array(video_cf.astype(jnp.float32))).to(torch.bfloat16)
    return self.video_processor.postprocess_video(video_torch, output_type="np")
IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"


def to_numpy(array):
  """Convert a torch tensor (including bfloat16) or array-like to numpy."""
  if isinstance(array, torch.Tensor):
    # numpy has no native bfloat16; upcast to float32 first.
    if array.dtype == torch.bfloat16:
      array = array.float()
    return array.detach().cpu().numpy()
  return np.asarray(array)


def hf_channel_first_to_last(array):
  """Move a (B, C, T, H, W) array to channel-last (B, T, H, W, C) numpy."""
  return np.moveaxis(to_numpy(array), 1, -1)


class FakeTokenBatch:
  """Minimal tokenizer-batch stand-in exposing ids and attention mask."""

  def __init__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
    self.input_ids = input_ids
    self.attention_mask = attention_mask


class FakeTokenizer:
  """Deterministic tokenizer stub: ids derived from character sums."""

  def __call__(
      self,
      prompt,
      padding,
      max_length,
      truncation,
      add_special_tokens,
      return_attention_mask,
      return_tensors,
  ):
    del padding, truncation, add_special_tokens, return_attention_mask, return_tensors
    all_ids, all_masks = [], []
    for text in prompt:
      # Sequence length tracks the word count (clamped to [1, max_length]);
      # the id offset is a cheap deterministic hash of the text.
      seq_len = max(1, min(max_length, len(text.split()) + 1))
      base = (sum(ord(ch) for ch in text) % 37) + 1
      tokens = torch.arange(base, base + seq_len, dtype=torch.long)
      tail = torch.zeros(max_length - seq_len, dtype=torch.long)
      all_ids.append(torch.cat([tokens, tail], dim=0))
      all_masks.append(torch.cat([torch.ones(seq_len, dtype=torch.long), tail], dim=0))
    return FakeTokenBatch(torch.stack(all_ids), torch.stack(all_masks))


class FakeTextEncoder:
  """Text-encoder stub producing a deterministic 4-channel embedding."""

  dtype = torch.float32

  def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
    ids_f = input_ids.float()
    mask_f = attention_mask.float()
    channels = [ids_f, ids_f * 0.5 + mask_f, mask_f * 2.0, ids_f - mask_f]
    return SimpleNamespace(last_hidden_state=torch.stack(channels, dim=-1))


class FakeImageBatch(dict):
  """Dict-backed stand-in for an image-processor output batch."""

  @property
  def pixel_values(self):
    return self["pixel_values"]

  def to(self, device=None):
    moved = self.pixel_values if device is None else self.pixel_values.to(device)
    return FakeImageBatch(pixel_values=moved)


class FakeImageProcessor:
  """Image-processor stub: grayscale + bilinear downsample to 2x2."""

  def __call__(self, images, return_tensors):
    del return_tensors
    image_list = images if isinstance(images, list) else [images]
    processed = []
    for image in image_list:
      arr = np.asarray(image, dtype=np.float32)
      if arr.ndim == 2:
        arr = arr[..., None]
      if arr.shape[-1] > 1:
        # Collapse RGB to a single grayscale channel.
        arr = arr.mean(axis=-1, keepdims=True)
      tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0) / 255.0
      tensor = F.interpolate(tensor, size=(2, 2), mode="bilinear", align_corners=False)
      processed.append(tensor.squeeze(0))
    return FakeImageBatch(pixel_values=torch.stack(processed))


class FakeImageEncoder:
  """Image-encoder stub returning three scaled hidden states."""

  def __call__(self, pixel_values, output_hidden_states: bool):
    del output_hidden_states
    flat = pixel_values.reshape(pixel_values.shape[0], pixel_values.shape[1], -1)
    # Torch and JAX spell the (B, C, N) -> (B, N, C) transpose differently.
    if isinstance(pixel_values, torch.Tensor):
      flat = flat.transpose(1, 2)
    else:
      flat = jnp.transpose(flat, (0, 2, 1))
    return SimpleNamespace(hidden_states=[flat * 0.25, flat * 0.5, flat * 0.75])


class FakeTorchLatentDist:
  """Degenerate latent distribution whose sample and mode are a fixed tensor."""

  def __init__(self, latents: torch.Tensor):
    self._latents = latents

  def sample(self, generator=None):
    del generator
    return self._latents

  def mode(self):
    return self._latents


class FakeTorchEncodeOutput:
  """Wraps latents the way a diffusers VAE encode() result does."""

  def __init__(self, latents: torch.Tensor):
    self.latent_dist = FakeTorchLatentDist(latents)


class FakeTorchVAE:
  """Torch VAE stub: strided slicing + 1.0 stands in for real encoding."""

  dtype = torch.float32

  class config:
    z_dim = 2
    latents_mean = [0.5, -0.25]
    latents_std = [2.0, 4.0]

  def encode(self, x: torch.Tensor):
    # 4x temporal / 8x spatial downsample via strides, two latent channels.
    return FakeTorchEncodeOutput(x[:, :2, ::4, ::8, ::8] + 1.0)


class FakeJaxEncodeOutput:
  """Wraps latents the way the JAX VAE encode() result does."""

  def __init__(self, latents: jnp.ndarray):
    self._latents = latents

  def mode(self):
    return self._latents


class FakeJaxVAE:
  """JAX VAE stub mirroring FakeTorchVAE, but returning channel-last latents."""

  dtype = jnp.float32
  z_dim = 2
  latents_mean = [0.5, -0.25]
  latents_std = [2.0, 4.0]

  def encode(self, x: jnp.ndarray, cache):
    del cache
    strided = x[:, :2, ::4, ::8, ::8] + 1.0
    return (FakeJaxEncodeOutput(jnp.transpose(strided, (0, 2, 3, 4, 1))),)
WanAnimateDiffusersParityTest(unittest.TestCase): + + def setUp(self): + self.np_rng = np.random.default_rng(0) + self.torch_generator = torch.Generator().manual_seed(0) + + self.max_pipeline = MaxWanAnimatePipeline.__new__(MaxWanAnimatePipeline) + self.max_pipeline.tokenizer = FakeTokenizer() + self.max_pipeline.text_encoder = FakeTextEncoder() + self.max_pipeline.image_processor = FakeImageProcessor() + self.max_pipeline.image_encoder = FakeImageEncoder() + self.max_pipeline.vae = FakeJaxVAE() + self.max_pipeline.vae_scale_factor_temporal = 4 + self.max_pipeline.vae_scale_factor_spatial = 8 + self.max_pipeline.mesh = nullcontext() + self.max_pipeline.vae_mesh = nullcontext() + self.max_pipeline.config = SimpleNamespace(logical_axis_rules=()) + self.max_pipeline.vae_logical_axis_rules = () + self.max_pipeline.vae_cache = None + self.max_pipeline.video_processor_for_mask = MaxVideoProcessor( + vae_scale_factor=8, do_normalize=False, do_convert_grayscale=True + ) + + self.hf_pipeline = HFWanAnimatePipeline.__new__(HFWanAnimatePipeline) + self.hf_pipeline.tokenizer = self.max_pipeline.tokenizer + self.hf_pipeline.text_encoder = self.max_pipeline.text_encoder + self.hf_pipeline.image_processor = self.max_pipeline.image_processor + self.hf_pipeline.image_encoder = self.max_pipeline.image_encoder + self.hf_pipeline.vae = FakeTorchVAE() + self.hf_pipeline.vae_scale_factor_temporal = 4 + self.hf_pipeline.vae_scale_factor_spatial = 8 + self.hf_pipeline.video_processor_for_mask = HFVideoProcessor( + vae_scale_factor=8, do_normalize=False, do_convert_grayscale=True + ) + + def _random_float_array(self, shape, low=-1.0, high=1.0): + return self.np_rng.uniform(low, high, size=shape).astype(np.float32) + + def _random_jax_array(self, shape, low=-1.0, high=1.0): + return jnp.array(self._random_float_array(shape, low=low, high=high)) + + def _random_torch_tensor(self, shape, low=-1.0, high=1.0): + tensor = torch.rand(shape, generator=self.torch_generator, dtype=torch.float32) + 
return tensor * (high - low) + low + + def _random_mask_tensor(self, shape, threshold=0.5): + return (torch.rand(shape, generator=self.torch_generator, dtype=torch.float32) > threshold).float() + + def _random_rgb_image(self, height, width): + pixels = self.np_rng.integers(0, 256, size=(height, width, 3), dtype=np.uint8) + return PIL.Image.fromarray(pixels) + + def _random_mask_image(self, height, width): + pixels = self.np_rng.integers(0, 256, size=(height, width), dtype=np.uint8) + return PIL.Image.fromarray(pixels) + + def _configure_pipeline_for_call_test(self): + devices = np.array(jax.devices()) + self.max_pipeline.mesh = Mesh(devices.reshape((devices.size,)), ("data",)) + self.max_pipeline.config = SimpleNamespace( + logical_axis_rules=(), + height=16, + width=16, + seed=0, + global_batch_size_to_train_on=1, + per_device_batch_size=1, + data_sharding=("data",), + activations_dtype=jnp.float32, + ) + self.max_pipeline.ref_image_processor = MaxVaeImageProcessor( + vae_scale_factor=8, + spatial_patch_size=(2, 2), + resize_mode="fill", + fill_color=0, + ) + self.max_pipeline.video_processor = MaxVideoProcessor(vae_scale_factor=8) + self.max_pipeline.transformer = SimpleNamespace( + config=SimpleNamespace(patch_size=(1, 2, 2)), + motion_encoder=SimpleNamespace(size=16), + ) + + class FakeScheduler: + + def set_timesteps(self, state, num_inference_steps, shape): + del state, shape + return SimpleNamespace(timesteps=jnp.arange(num_inference_steps, 0, -1, dtype=jnp.int32)) + + def step(self, state, noise_pred, t, sample, return_dict=False): + del t, return_dict + return sample - noise_pred * 0.1, state + + self.max_pipeline.scheduler = FakeScheduler() + self.max_pipeline.scheduler_state = SimpleNamespace() + + def test_encode_prompt_matches_diffusers(self): + prompt = [" Hello world ", "test & check"] + negative_prompt = ["bad motion", "low detail"] + + max_prompt, max_negative = MaxWanPipeline.encode_prompt( + self.max_pipeline, + prompt=prompt, + 
negative_prompt=negative_prompt, + num_videos_per_prompt=2, + max_sequence_length=8, + ) + hf_prompt, hf_negative = HFWanAnimatePipeline.encode_prompt( + self.hf_pipeline, + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=True, + num_videos_per_prompt=2, + max_sequence_length=8, + device=torch.device("cpu"), + ) + + np.testing.assert_allclose(to_numpy(max_prompt), to_numpy(hf_prompt), atol=0.0, rtol=0.0) + np.testing.assert_allclose(to_numpy(max_negative), to_numpy(hf_negative), atol=0.0, rtol=0.0) + + def test_encode_image_matches_diffusers_call_semantics(self): + image = self._random_rgb_image(17, 19) + + max_image = MaxWanPipeline.encode_image(self.max_pipeline, image, num_videos_per_prompt=3) + hf_image = HFWanAnimatePipeline.encode_image(self.hf_pipeline, image, device=torch.device("cpu")).repeat(3, 1, 1) + + np.testing.assert_allclose(to_numpy(max_image), to_numpy(hf_image), atol=0.0, rtol=0.0) + + def test_pad_video_frames_matches_diffusers(self): + frames = [1, 2, 3, 4, 5] + + max_frames = MaxWanAnimatePipeline.pad_video_frames(frames, 10) + hf_frames = HFWanAnimatePipeline.pad_video_frames(self.hf_pipeline, frames, 10) + + self.assertEqual(max_frames, hf_frames) + + def test_prepare_reference_image_latents_matches_diffusers(self): + image = self._random_torch_tensor((1, 3, 16, 16)) + + max_latents = MaxWanAnimatePipeline.prepare_reference_image_latents( + self.max_pipeline, jnp.array(image.numpy()), batch_size=2, dtype=jnp.float32 + ) + hf_latents = HFWanAnimatePipeline.prepare_reference_image_latents( + self.hf_pipeline, image, batch_size=2, dtype=torch.float32, device=torch.device("cpu") + ) + + np.testing.assert_allclose(to_numpy(max_latents), hf_channel_first_to_last(hf_latents), atol=0.0, rtol=0.0) + + def test_prepare_pose_latents_matches_diffusers(self): + pose_video = self._random_torch_tensor((1, 3, 9, 16, 16)) + + max_latents = MaxWanAnimatePipeline.prepare_pose_latents( + self.max_pipeline, 
jnp.array(pose_video.numpy()), batch_size=2, dtype=jnp.float32 + ) + hf_latents = HFWanAnimatePipeline.prepare_pose_latents( + self.hf_pipeline, pose_video, batch_size=2, dtype=torch.float32, device=torch.device("cpu") + ) + + np.testing.assert_allclose(to_numpy(max_latents), hf_channel_first_to_last(hf_latents), atol=0.0, rtol=0.0) + + def test_prepare_segment_latents_matches_diffusers_when_latents_are_provided(self): + max_input = self._random_jax_array((1, 4, 2, 2, 2)) + hf_input = torch.tensor(np.transpose(to_numpy(max_input), (0, 4, 1, 2, 3))) + + max_latents = MaxWanAnimatePipeline.prepare_segment_latents( + self.max_pipeline, + batch_size=1, + height=16, + width=16, + segment_frame_length=9, + dtype=jnp.bfloat16, + rng=jnp.array([0, 1], dtype=jnp.uint32), + latents=max_input, + ) + hf_latents = HFWanAnimatePipeline.prepare_latents( + self.hf_pipeline, + batch_size=1, + num_channels_latents=2, + height=16, + width=16, + num_frames=9, + dtype=torch.bfloat16, + device=torch.device("cpu"), + latents=hf_input, + ) + + np.testing.assert_allclose(to_numpy(max_latents), hf_channel_first_to_last(hf_latents), atol=0.0, rtol=0.0) + + def test_prepare_segment_latents_samples_expected_shape_dtype_and_values(self): + rng = jax.random.key(7) + + actual = MaxWanAnimatePipeline.prepare_segment_latents( + self.max_pipeline, + batch_size=2, + height=16, + width=16, + segment_frame_length=9, + dtype=jnp.bfloat16, + rng=rng, + ) + expected = jax.random.normal(rng, shape=(2, 4, 2, 2, 2), dtype=jnp.float32).astype(jnp.bfloat16) + + self.assertEqual(actual.shape, (2, 4, 2, 2, 2)) + self.assertEqual(actual.dtype, jnp.bfloat16) + np.testing.assert_array_equal(to_numpy(actual), to_numpy(expected)) + + def test_get_i2v_mask_constructs_expected_temporal_layout(self): + mask_pixel_values = jnp.arange(5, dtype=jnp.float32).reshape(1, 1, 5, 1, 1) + + actual = MaxWanAnimatePipeline.get_i2v_mask( + self.max_pipeline, + batch_size=1, + latent_t=2, + latent_h=1, + latent_w=1, + mask_len=1, + 
mask_pixel_values=mask_pixel_values, + dtype=jnp.float32, + ) + expected = jnp.array([[[[[1.0, 1.0, 1.0, 1.0]]], [[[1.0, 2.0, 3.0, 4.0]]]]], dtype=jnp.float32) + + np.testing.assert_allclose(to_numpy(actual), to_numpy(expected), atol=0.0, rtol=0.0) + + def test_prepare_prev_segment_cond_latents_matches_diffusers_for_animate(self): + prev_segment = self._random_torch_tensor((1, 3, 1, 16, 16)) + + max_latents = MaxWanAnimatePipeline.prepare_prev_segment_cond_latents( + self.max_pipeline, + prev_segment_cond_video=jnp.array(prev_segment.numpy()), + background_video=None, + mask_video=None, + batch_size=1, + segment_frame_length=9, + start_frame=4, + height=16, + width=16, + prev_segment_cond_frames=1, + task="animate", + dtype=jnp.float32, + ) + hf_latents = HFWanAnimatePipeline.prepare_prev_segment_cond_latents( + self.hf_pipeline, + prev_segment_cond_video=prev_segment, + background_video=None, + mask_video=None, + batch_size=1, + segment_frame_length=9, + start_frame=4, + height=16, + width=16, + prev_segment_cond_frames=1, + task="animate", + dtype=torch.float32, + device=torch.device("cpu"), + ) + + np.testing.assert_allclose(to_numpy(max_latents), hf_channel_first_to_last(hf_latents), atol=0.0, rtol=0.0) + + def test_prepare_prev_segment_cond_latents_animate_encodes_full_segment_like_diffusers(self): + call_lengths = [] + + def fake_encode(video, dtype): + del dtype + call_lengths.append(video.shape[2]) + latent_t = (video.shape[2] - 1) // self.max_pipeline.vae_scale_factor_temporal + 1 + latent_h = video.shape[3] // self.max_pipeline.vae_scale_factor_spatial + latent_w = video.shape[4] // self.max_pipeline.vae_scale_factor_spatial + return jnp.ones((video.shape[0], latent_t, latent_h, latent_w, self.max_pipeline.vae.z_dim), dtype=jnp.float32) + + self.max_pipeline._encode_video_to_latents = fake_encode + prev_segment = jnp.ones((1, 3, 1, 16, 16), dtype=jnp.float32) + + _ = MaxWanAnimatePipeline.prepare_prev_segment_cond_latents( + self.max_pipeline, + 
prev_segment_cond_video=prev_segment, + background_video=None, + mask_video=None, + batch_size=1, + segment_frame_length=9, + start_frame=4, + height=16, + width=16, + prev_segment_cond_frames=1, + task="animate", + dtype=jnp.float32, + ) + + self.assertEqual(call_lengths, [9]) + + def test_prepare_prev_segment_cond_latents_animate_first_segment_encodes_zero_filled_segment_like_diffusers(self): + call_lengths = [] + + def fake_encode(video, dtype): + del dtype + call_lengths.append(video.shape[2]) + latent_t = (video.shape[2] - 1) // self.max_pipeline.vae_scale_factor_temporal + 1 + latent_h = video.shape[3] // self.max_pipeline.vae_scale_factor_spatial + latent_w = video.shape[4] // self.max_pipeline.vae_scale_factor_spatial + return jnp.zeros( + (video.shape[0], latent_t, latent_h, latent_w, self.max_pipeline.vae.z_dim), dtype=jnp.float32 + ) + + self.max_pipeline._encode_video_to_latents = fake_encode + + _ = MaxWanAnimatePipeline.prepare_prev_segment_cond_latents( + self.max_pipeline, + prev_segment_cond_video=None, + background_video=None, + mask_video=None, + batch_size=1, + segment_frame_length=9, + start_frame=0, + height=16, + width=16, + prev_segment_cond_frames=1, + task="animate", + dtype=jnp.float32, + ) + + self.assertEqual(call_lengths, [9]) + + def test_prepare_prev_segment_cond_latents_matches_diffusers_for_replace(self): + prev_segment = self._random_torch_tensor((1, 3, 1, 16, 16)) + background = self._random_torch_tensor((1, 3, 9, 16, 16)) + mask = self._random_mask_tensor((1, 1, 9, 16, 16)) + + max_latents = MaxWanAnimatePipeline.prepare_prev_segment_cond_latents( + self.max_pipeline, + prev_segment_cond_video=jnp.array(prev_segment.numpy()), + background_video=jnp.array(background.numpy()), + mask_video=jnp.array(mask.numpy()), + batch_size=1, + segment_frame_length=9, + start_frame=4, + height=16, + width=16, + prev_segment_cond_frames=1, + task="replace", + dtype=jnp.float32, + ) + hf_latents = 
HFWanAnimatePipeline.prepare_prev_segment_cond_latents( + self.hf_pipeline, + prev_segment_cond_video=prev_segment, + background_video=background, + mask_video=mask, + batch_size=1, + segment_frame_length=9, + start_frame=4, + height=16, + width=16, + prev_segment_cond_frames=1, + task="replace", + dtype=torch.float32, + device=torch.device("cpu"), + ) + + np.testing.assert_allclose(to_numpy(max_latents), hf_channel_first_to_last(hf_latents), atol=0.0, rtol=0.0) + + def test_resize_mask_to_latent_spatial_matches_torch_nearest(self): + mask = self._random_mask_tensor((1, 1, 9, 16, 16)) + hf_mask = mask.permute(0, 2, 1, 3, 4).flatten(0, 1) + hf_mask = F.interpolate(hf_mask, size=(2, 2), mode="nearest") + hf_mask = hf_mask.unflatten(0, (1, -1)).permute(0, 2, 1, 3, 4) + + max_mask = MaxWanAnimatePipeline._resize_mask_to_latent_spatial(self.max_pipeline, jnp.array(mask.numpy()), 2, 2) + + np.testing.assert_allclose(to_numpy(max_mask), to_numpy(hf_mask), atol=0.0, rtol=0.0) + + def test_reference_image_processor_matches_diffusers_fill_resize(self): + image = self._random_rgb_image(6, 10) + max_processor = MaxVaeImageProcessor( + vae_scale_factor=8, + spatial_patch_size=(2, 2), + resize_mode="fill", + fill_color=0, + ) + hf_processor = HFWanAnimateImageProcessor(vae_scale_factor=8, spatial_patch_size=(2, 2), fill_color=0) + + max_image = max_processor.preprocess(image, height=16, width=16) + hf_image = hf_processor.preprocess(image, height=16, width=16, resize_mode="fill") + + np.testing.assert_allclose(to_numpy(max_image), to_numpy(hf_image), atol=0.0, rtol=0.0) + + def test_video_processor_matches_diffusers(self): + frames = [self._random_rgb_image(9, 13) for _ in range(3)] + max_processor = MaxVideoProcessor(vae_scale_factor=8) + hf_processor = HFVideoProcessor(vae_scale_factor=8) + + max_video = max_processor.preprocess_video(frames, height=16, width=16) + hf_video = hf_processor.preprocess_video(frames, height=16, width=16) + + 
np.testing.assert_allclose(to_numpy(max_video), to_numpy(hf_video), atol=0.0, rtol=0.0) + + def test_mask_video_preprocessing_matches_diffusers(self): + masks = [self._random_mask_image(9, 13) for _ in range(3)] + + max_mask = self.max_pipeline.video_processor_for_mask.preprocess_video(masks, height=16, width=16) + hf_mask = self.hf_pipeline.video_processor_for_mask.preprocess_video(masks, height=16, width=16) + + np.testing.assert_allclose(to_numpy(max_mask), to_numpy(hf_mask), atol=0.0, rtol=0.0) + + def test_check_inputs_matches_diffusers_validation(self): + invalid_calls = [ + dict( + prompt="prompt", + negative_prompt=None, + image=PIL.Image.new("RGB", (16, 16)), + pose_video=[PIL.Image.new("RGB", (16, 16))], + face_video=[PIL.Image.new("RGB", (16, 16))], + background_video=None, + mask_video=None, + height=16, + width=16, + prompt_embeds=jnp.zeros((1, 1, 1)), + negative_prompt_embeds=None, + image_embeds=None, + mode="animate", + prev_segment_conditioning_frames=1, + ), + dict( + prompt="prompt", + negative_prompt=None, + image=PIL.Image.new("RGB", (16, 16)), + pose_video=[PIL.Image.new("RGB", (16, 16))], + face_video=[PIL.Image.new("RGB", (16, 16))], + background_video=None, + mask_video=None, + height=18, + width=16, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + mode="animate", + prev_segment_conditioning_frames=1, + ), + dict( + prompt="prompt", + negative_prompt=None, + image=PIL.Image.new("RGB", (16, 16)), + pose_video=[PIL.Image.new("RGB", (16, 16))], + face_video=[PIL.Image.new("RGB", (16, 16))], + background_video=None, + mask_video=None, + height=16, + width=16, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + mode="replace", + prev_segment_conditioning_frames=3, + ), + ] + + for kwargs in invalid_calls: + with self.subTest(kwargs=kwargs): + with self.assertRaises(ValueError) as max_ctx: + self.max_pipeline.check_inputs(**kwargs) + with self.assertRaises(ValueError) as hf_ctx: + 
self.hf_pipeline.check_inputs(**kwargs) + self.assertEqual(str(max_ctx.exception), str(hf_ctx.exception)) + + def test_animate_transformer_forward_pass_matches_diffusers_layout(self): + capture = {} + + class FakeTransformer: + + def __call__( + self, + hidden_states, + timestep, + encoder_hidden_states, + encoder_hidden_states_image, + pose_hidden_states, + face_pixel_values, + motion_encode_batch_size, + return_dict, + ): + capture["hidden_states"] = hidden_states + capture["timestep"] = timestep + capture["encoder_hidden_states"] = encoder_hidden_states + capture["encoder_hidden_states_image"] = encoder_hidden_states_image + capture["pose_hidden_states"] = pose_hidden_states + capture["face_pixel_values"] = face_pixel_values + capture["motion_encode_batch_size"] = motion_encode_batch_size + capture["return_dict"] = return_dict + return (hidden_states[:, :2],) + + latents = self._random_jax_array((1, 3, 2, 2, 2)) + reference_latents = self._random_jax_array((1, 3, 2, 2, 2)) + pose_latents = self._random_jax_array((1, 3, 2, 2, 2)) + face_video = self._random_jax_array((1, 3, 9, 4, 4), low=0.0, high=1.0) + timestep = jnp.array([7], dtype=jnp.int32) + prompt_embeds = self._random_jax_array((1, 4, 3)) + image_embeds = self._random_jax_array((1, 4, 3)) + + with mock.patch("maxdiffusion.pipelines.wan.wan_pipeline_animate.nnx.merge", return_value=FakeTransformer()): + noise_pred = animate_transformer_forward_pass.__wrapped__( + graphdef=None, + sharded_state=None, + rest_of_state=None, + latents=latents, + reference_latents=reference_latents, + pose_latents=pose_latents, + face_video_segment=face_video, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + motion_encode_batch_size=7, + ) + + expected_hidden = jnp.transpose(jnp.concatenate([latents, reference_latents], axis=-1), (0, 4, 1, 2, 3)) + expected_pose = jnp.transpose(pose_latents, (0, 4, 1, 2, 3)) + + 
np.testing.assert_allclose(to_numpy(capture["hidden_states"]), to_numpy(expected_hidden), atol=0.0, rtol=0.0) + np.testing.assert_allclose(to_numpy(capture["pose_hidden_states"]), to_numpy(expected_pose), atol=0.0, rtol=0.0) + np.testing.assert_allclose(to_numpy(capture["face_pixel_values"]), to_numpy(face_video), atol=0.0, rtol=0.0) + np.testing.assert_allclose(to_numpy(capture["encoder_hidden_states"]), to_numpy(prompt_embeds), atol=0.0, rtol=0.0) + np.testing.assert_allclose( + to_numpy(capture["encoder_hidden_states_image"]), to_numpy(image_embeds), atol=0.0, rtol=0.0 + ) + self.assertEqual(capture["motion_encode_batch_size"], 7) + self.assertFalse(capture["return_dict"]) + np.testing.assert_allclose(to_numpy(noise_pred), to_numpy(latents), atol=0.0, rtol=0.0) + + def test_single_denoising_step_matches_diffusers_with_cfg(self): + class FakeDenoiseTransformer: + + def __call__( + self, + hidden_states, + timestep, + encoder_hidden_states, + encoder_hidden_states_image, + pose_hidden_states, + face_pixel_values, + motion_encode_batch_size, + return_dict, + ): + del motion_encode_batch_size, return_dict + + def _scalar(x): + if isinstance(x, torch.Tensor): + return x.float().mean(dim=tuple(range(1, x.ndim))).view(-1, 1, 1, 1, 1) + return jnp.mean(x.astype(jnp.float32), axis=tuple(range(1, x.ndim))).reshape((-1, 1, 1, 1, 1)) + + noise = ( + hidden_states[:, :2] * 0.5 + + pose_hidden_states[:, :2] * 0.1 + + _scalar(encoder_hidden_states) * 0.01 + + _scalar(encoder_hidden_states_image) * 0.02 + + _scalar(face_pixel_values) * 0.03 + + _scalar(timestep) * 0.001 + ) + return (noise,) + + guidance_scale = 3.0 + timestep_count = 4 + fake_transformer = FakeDenoiseTransformer() + + max_latents = self._random_jax_array((1, 3, 2, 2, 2)) + max_reference = self._random_jax_array((1, 3, 2, 2, 2)) + max_pose = self._random_jax_array((1, 3, 2, 2, 2)) + max_face = self._random_jax_array((1, 3, 9, 4, 4), low=0.0, high=1.0) + max_prompt = self._random_jax_array((1, 4, 3)) + 
max_negative = self._random_jax_array((1, 4, 3)) + max_image = self._random_jax_array((1, 4, 3)) + + hf_latents = torch.tensor(np.transpose(to_numpy(max_latents), (0, 4, 1, 2, 3))) + hf_reference = torch.tensor(np.transpose(to_numpy(max_reference), (0, 4, 1, 2, 3))) + hf_pose = torch.tensor(np.transpose(to_numpy(max_pose), (0, 4, 1, 2, 3))) + hf_face = torch.tensor(to_numpy(max_face)) + hf_prompt = torch.tensor(to_numpy(max_prompt)) + hf_negative = torch.tensor(to_numpy(max_negative)) + hf_image = torch.tensor(to_numpy(max_image)) + + scheduler_config = dict(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=5.0) + max_scheduler = FlaxUniPCMultistepScheduler(**scheduler_config) + max_state = max_scheduler.create_state() + max_state = max_scheduler.set_timesteps(max_state, num_inference_steps=timestep_count, shape=max_latents.shape) + + hf_scheduler = UniPCMultistepScheduler(**scheduler_config) + hf_scheduler.set_timesteps(timestep_count, device="cpu") + + timestep = int(to_numpy(hf_scheduler.timesteps[0])) + max_timestep = jnp.full((max_latents.shape[0],), timestep, dtype=jnp.int32) + hf_timestep = torch.full((hf_latents.shape[0],), timestep, dtype=torch.int64) + + with mock.patch("maxdiffusion.pipelines.wan.wan_pipeline_animate.nnx.merge", return_value=fake_transformer): + max_noise_cond = animate_transformer_forward_pass.__wrapped__( + graphdef=None, + sharded_state=None, + rest_of_state=None, + latents=max_latents, + reference_latents=max_reference, + pose_latents=max_pose, + face_video_segment=max_face, + timestep=max_timestep, + encoder_hidden_states=max_prompt, + encoder_hidden_states_image=max_image, + motion_encode_batch_size=5, + ) + max_noise_uncond = animate_transformer_forward_pass.__wrapped__( + graphdef=None, + sharded_state=None, + rest_of_state=None, + latents=max_latents, + reference_latents=max_reference, + pose_latents=max_pose, + face_video_segment=max_face * 0 - 1, + timestep=max_timestep, + encoder_hidden_states=max_negative, 
+ encoder_hidden_states_image=max_image, + motion_encode_batch_size=5, + ) + + max_noise = max_noise_uncond + guidance_scale * (max_noise_cond - max_noise_uncond) + + hf_latent_model_input = torch.cat([hf_latents, hf_reference], dim=1) + hf_noise_cond = fake_transformer( + hidden_states=hf_latent_model_input, + timestep=hf_timestep, + encoder_hidden_states=hf_prompt, + encoder_hidden_states_image=hf_image, + pose_hidden_states=hf_pose, + face_pixel_values=hf_face, + motion_encode_batch_size=5, + return_dict=False, + )[0] + hf_noise_uncond = fake_transformer( + hidden_states=hf_latent_model_input, + timestep=hf_timestep, + encoder_hidden_states=hf_negative, + encoder_hidden_states_image=hf_image, + pose_hidden_states=hf_pose, + face_pixel_values=hf_face * 0 - 1, + motion_encode_batch_size=5, + return_dict=False, + )[0] + hf_noise = hf_noise_uncond + guidance_scale * (hf_noise_cond - hf_noise_uncond) + + max_next, _ = max_scheduler.step(max_state, max_noise, timestep, max_latents, return_dict=False) + hf_next = hf_scheduler.step(hf_noise, timestep, hf_latents, return_dict=False)[0] + + np.testing.assert_allclose(to_numpy(max_noise), hf_channel_first_to_last(hf_noise), atol=1e-6, rtol=1e-6) + np.testing.assert_allclose(to_numpy(max_next), hf_channel_first_to_last(hf_next), atol=1e-5, rtol=1e-5) + + def test_flax_unipc_flow_sigmas_match_diffusers(self): + scheduler_config = dict(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=5.0) + + max_scheduler = FlaxUniPCMultistepScheduler(**scheduler_config) + max_state = max_scheduler.create_state() + max_state = max_scheduler.set_timesteps(max_state, num_inference_steps=4, shape=(1, 2, 3, 4, 5)) + + hf_scheduler = UniPCMultistepScheduler(**scheduler_config) + hf_scheduler.set_timesteps(4, device="cpu") + + np.testing.assert_array_equal(to_numpy(max_state.timesteps), to_numpy(hf_scheduler.timesteps)) + np.testing.assert_allclose(to_numpy(max_state.sigmas), to_numpy(hf_scheduler.sigmas), atol=1e-7, rtol=0.0) + 
+ max_sample = self._random_jax_array((1, 2, 3, 4, 5)) + hf_sample = torch.tensor(to_numpy(max_sample)) + + for timestep in to_numpy(hf_scheduler.timesteps[:3]): + hf_model_output = self._random_torch_tensor(tuple(hf_sample.shape), low=-0.25, high=0.25) + max_model_output = jnp.array(to_numpy(hf_model_output)) + + hf_sample = hf_scheduler.step(hf_model_output, int(timestep), hf_sample, return_dict=False)[0] + max_sample, max_state = max_scheduler.step( + max_state, max_model_output, int(timestep), max_sample, return_dict=False + ) + + np.testing.assert_allclose(to_numpy(max_sample), to_numpy(hf_sample), atol=1e-4, rtol=1e-5) + + def test_call_runs_multisegment_pipeline_and_trims_output(self): + self._configure_pipeline_for_call_test() + image = self._random_rgb_image(16, 16) + pose_video = [self._random_rgb_image(16, 16) for _ in range(10)] + face_video = [self._random_rgb_image(16, 16) for _ in range(10)] + decode_shapes = [] + denoise_shapes = [] + + def fake_denoise( + graphdef, + sharded_state, + rest_of_state, + latents, + reference_latents, + pose_latents, + face_video_segment, + timestep, + encoder_hidden_states, + encoder_hidden_states_image, + motion_encode_batch_size=None, + ): + del graphdef, sharded_state, rest_of_state, reference_latents, encoder_hidden_states, encoder_hidden_states_image + del motion_encode_batch_size + denoise_shapes.append((latents.shape, pose_latents.shape, face_video_segment.shape, timestep.shape)) + return jnp.ones_like(latents) * 0.25 + + def fake_decode(latents_cl): + decode_shapes.append(latents_cl.shape) + batch, latent_t, latent_h, latent_w, _ = latents_cl.shape + pixel_t = (latent_t - 1) * self.max_pipeline.vae_scale_factor_temporal + 1 + height = latent_h * self.max_pipeline.vae_scale_factor_spatial + width = latent_w * self.max_pipeline.vae_scale_factor_spatial + return self._random_jax_array((batch, 3, pixel_t, height, width)) + + with mock.patch("maxdiffusion.pipelines.wan.wan_pipeline_animate.nnx.split", 
return_value=(None, None, None)): + with mock.patch( + "maxdiffusion.pipelines.wan.wan_pipeline_animate.animate_transformer_forward_pass", side_effect=fake_denoise + ): + with mock.patch.object(self.max_pipeline, "_decode_segment_to_pixels", side_effect=fake_decode): + output = self.max_pipeline( + image=image, + pose_video=pose_video, + face_video=face_video, + prompt="animate this", + negative_prompt=None, + height=16, + width=16, + segment_frame_length=10, + num_inference_steps=2, + mode="animate", + prev_segment_conditioning_frames=1, + guidance_scale=1.0, + output_type="np", + ) + + self.assertEqual(output.shape, (1, 10, 16, 16, 3)) + self.assertEqual(len(decode_shapes), 2) + self.assertEqual(decode_shapes[0], (1, 3, 2, 2, 2)) + self.assertEqual(len(denoise_shapes), 4) + self.assertEqual(denoise_shapes[0], ((1, 4, 2, 2, 2), (1, 3, 2, 2, 2), (1, 3, 9, 16, 16), (1,))) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/maxdiffusion/tests/wan_animate_module_parity_test.py b/src/maxdiffusion/tests/wan_animate_module_parity_test.py new file mode 100644 index 00000000..d0d74d3b --- /dev/null +++ b/src/maxdiffusion/tests/wan_animate_module_parity_test.py @@ -0,0 +1,534 @@ +""" +Copyright 2026 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import os +import unittest +from importlib import resources + +os.environ["JAX_PLATFORMS"] = "cpu" +os.environ.setdefault("MPLCONFIGDIR", "/tmp/mplconfig") + +import jax +import jax.numpy as jnp +import numpy as np +import pytest +import torch +from flax import nnx +from flax.linen import partitioning as nn_partitioning +from flax.traverse_util import flatten_dict +from jax.sharding import Mesh + +jax.config.update("jax_platforms", "cpu") + +from diffusers.models.transformers.transformer_wan_animate import ( + FusedLeakyReLU as HFFusedLeakyReLU, + MotionConv2d as HFMotionConv2d, + MotionEncoderResBlock as HFMotionEncoderResBlock, + MotionLinear as HFMotionLinear, + WanAnimateFaceBlockCrossAttention as HFWanAnimateFaceBlockCrossAttention, + WanAnimateFaceEncoder as HFWanAnimateFaceEncoder, + WanAnimateMotionEncoder as HFWanAnimateMotionEncoder, + WanAnimateTransformer3DModel as HFWanAnimateTransformer3DModel, +) + +from maxdiffusion import pyconfig +from maxdiffusion.max_utils import create_device_mesh +from maxdiffusion.models.wan.transformers.transformer_wan_animate import ( + FusedLeakyReLU, + MotionConv2d, + MotionEncoderResBlock, + MotionLinear, + NNXWanAnimateTransformer3DModel, + WanAnimateFaceBlockCrossAttention, + WanAnimateFaceEncoder, + WanAnimateMotionEncoder, +) +from maxdiffusion.models.wan.wan_utils import ( + _rename_wan_animate_pt_tuple_key, + get_wan_animate_key_and_value, +) + +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" + + +def to_numpy(array): + if isinstance(array, torch.Tensor): + if array.dtype == torch.bfloat16: + array = array.float() + return array.detach().cpu().numpy() + return np.asarray(array) + + +def assert_allclose(test_case, actual, expected, *, atol=1e-6, rtol=1e-6): + test_case.assertEqual(to_numpy(actual).shape, to_numpy(expected).shape) + np.testing.assert_allclose(to_numpy(actual), to_numpy(expected), atol=atol, rtol=rtol) + + +def copy_fused_leaky_relu_params(max_module, hf_module): + if max_module.bias 
is not None: + max_module.bias[...] = jnp.asarray(to_numpy(hf_module.bias)) + + +def copy_motion_conv2d_params(max_module, hf_module): + max_module.weight[...] = jnp.asarray(to_numpy(hf_module.weight)) + if max_module.bias is not None and hf_module.bias is not None: + max_module.bias[...] = jnp.asarray(to_numpy(hf_module.bias)) + if max_module.act_fn is not None and hf_module.act_fn is not None: + copy_fused_leaky_relu_params(max_module.act_fn, hf_module.act_fn) + + +def copy_motion_linear_params(max_module, hf_module): + max_module.weight[...] = jnp.asarray(to_numpy(hf_module.weight)) + if max_module.bias is not None and hf_module.bias is not None: + max_module.bias[...] = jnp.asarray(to_numpy(hf_module.bias)) + if max_module.act_fn is not None and hf_module.act_fn is not None: + copy_fused_leaky_relu_params(max_module.act_fn, hf_module.act_fn) + + +def copy_motion_encoder_resblock_params(max_module, hf_module): + copy_motion_conv2d_params(max_module.conv1, hf_module.conv1) + copy_motion_conv2d_params(max_module.conv2, hf_module.conv2) + copy_motion_conv2d_params(max_module.conv_skip, hf_module.conv_skip) + + +def copy_motion_encoder_params(max_module, hf_module): + copy_motion_conv2d_params(max_module.conv_in, hf_module.conv_in) + copy_motion_conv2d_params(max_module.conv_out, hf_module.conv_out) + for max_block, hf_block in zip(max_module.res_blocks, hf_module.res_blocks): + copy_motion_encoder_resblock_params(max_block, hf_block) + for max_linear, hf_linear in zip(max_module.motion_network, hf_module.motion_network): + copy_motion_linear_params(max_linear, hf_linear) + max_module.motion_synthesis_weight[...] = jnp.asarray(to_numpy(hf_module.motion_synthesis_weight)) + + +def copy_face_encoder_params(max_module, hf_module): + max_module.conv1_local.kernel[...] = jnp.asarray(np.transpose(to_numpy(hf_module.conv1_local.weight), (2, 1, 0))) + max_module.conv1_local.bias[...] = jnp.asarray(to_numpy(hf_module.conv1_local.bias)) + max_module.conv2.kernel[...] 
= jnp.asarray(np.transpose(to_numpy(hf_module.conv2.weight), (2, 1, 0))) + max_module.conv2.bias[...] = jnp.asarray(to_numpy(hf_module.conv2.bias)) + max_module.conv3.kernel[...] = jnp.asarray(np.transpose(to_numpy(hf_module.conv3.weight), (2, 1, 0))) + max_module.conv3.bias[...] = jnp.asarray(to_numpy(hf_module.conv3.bias)) + max_module.out_proj.kernel[...] = jnp.asarray(to_numpy(hf_module.out_proj.weight).T) + max_module.out_proj.bias[...] = jnp.asarray(to_numpy(hf_module.out_proj.bias)) + max_module.padding_tokens[...] = jnp.asarray(to_numpy(hf_module.padding_tokens)) + + +def copy_face_block_cross_attention_params(max_module, hf_module): + max_module.to_q.kernel[...] = jnp.asarray(to_numpy(hf_module.to_q.weight).T) + max_module.to_q.bias[...] = jnp.asarray(to_numpy(hf_module.to_q.bias)) + max_module.to_k.kernel[...] = jnp.asarray(to_numpy(hf_module.to_k.weight).T) + max_module.to_k.bias[...] = jnp.asarray(to_numpy(hf_module.to_k.bias)) + max_module.to_v.kernel[...] = jnp.asarray(to_numpy(hf_module.to_v.weight).T) + max_module.to_v.bias[...] = jnp.asarray(to_numpy(hf_module.to_v.bias)) + max_module.to_out.kernel[...] = jnp.asarray(to_numpy(hf_module.to_out.weight).T) + max_module.to_out.bias[...] = jnp.asarray(to_numpy(hf_module.to_out.bias)) + max_module.norm_q.scale[...] = jnp.asarray(to_numpy(hf_module.norm_q.weight)) + max_module.norm_k.scale[...] 
= jnp.asarray(to_numpy(hf_module.norm_k.weight)) + + +def map_hf_wan_animate_state_to_local(max_model, hf_model, num_layers, scan_layers=False): + state = nnx.state(max_model) + flat_vars = dict(nnx.to_flat_state(state)) + random_flax_state_dict = { + tuple(str(item) for item in key): value for key, value in flatten_dict(state.to_pure_dict()).items() + } + flax_state_dict = {} + + for pt_key, tensor in hf_model.state_dict().items(): + if "norm_added_q" in pt_key: + continue + + pt_tuple_key, is_motion_custom_weight = _rename_wan_animate_pt_tuple_key(pt_key) + flax_key, flax_tensor = get_wan_animate_key_and_value( + pt_tuple_key, + jnp.asarray(to_numpy(tensor)), + flax_state_dict, + random_flax_state_dict, + scan_layers, + is_motion_custom_weight=is_motion_custom_weight, + num_layers=num_layers, + ) + + flax_state_dict[flax_key] = jnp.asarray(flax_tensor) + + missing_keys = [key for key in flax_state_dict if key not in flat_vars] + for key, value in flax_state_dict.items(): + if key in flat_vars: + flat_vars[key][...] 
= value + + return missing_keys, flax_state_dict + + +@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run WAN parity tests on Github Actions") +class WanAnimateModuleParityTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + with resources.as_file(resources.files("maxdiffusion.configs").joinpath("base_wan_14b.yml")) as config_path: + pyconfig.initialize([None, os.fspath(config_path)], unittest=True) + config = pyconfig.config + cls.logical_axis_rules = config.logical_axis_rules + cls.mesh = Mesh(create_device_mesh(config), config.mesh_axes) + + def setUp(self): + torch.manual_seed(0) + self.rngs = nnx.Rngs(jax.random.key(0)) + + def test_fused_leaky_relu_parity(self): + hf_module = HFFusedLeakyReLU(bias_channels=3).eval() + max_module = FusedLeakyReLU(rngs=self.rngs, bias_channels=3) + copy_fused_leaky_relu_params(max_module, hf_module) + + inputs = torch.randn(2, 3, 4, 5) + expected = hf_module(inputs) + actual = max_module(jnp.asarray(to_numpy(inputs))) + + assert_allclose(self, actual, expected, atol=0.0, rtol=0.0) + + def test_motion_conv2d_parity(self): + hf_module = HFMotionConv2d(3, 5, kernel_size=3, stride=2, padding=0, blur_kernel=(1, 3, 3, 1)).eval() + max_module = MotionConv2d( + rngs=self.rngs, + in_channels=3, + out_channels=5, + kernel_size=3, + stride=2, + padding=0, + blur_kernel=(1, 3, 3, 1), + ) + copy_motion_conv2d_params(max_module, hf_module) + + inputs = torch.randn(2, 3, 8, 8) + expected = hf_module(inputs) + actual = max_module(jnp.asarray(to_numpy(inputs))) + + assert_allclose(self, actual, expected, atol=2e-7, rtol=2e-6) + + def test_motion_linear_parity(self): + hf_module = HFMotionLinear(7, 5, use_activation=True).eval() + max_module = MotionLinear(rngs=self.rngs, in_dim=7, out_dim=5, use_activation=True) + copy_motion_linear_params(max_module, hf_module) + + inputs = torch.randn(4, 7) + expected = hf_module(inputs) + actual = max_module(jnp.asarray(to_numpy(inputs))) + + assert_allclose(self, actual, expected, 
atol=1e-7, rtol=1e-7) + + def test_motion_encoder_resblock_parity(self): + hf_module = HFMotionEncoderResBlock(8, 10).eval() + max_module = MotionEncoderResBlock(rngs=self.rngs, in_channels=8, out_channels=10) + copy_motion_encoder_resblock_params(max_module, hf_module) + + inputs = torch.randn(2, 8, 8, 8) + expected = hf_module(inputs) + actual = max_module(jnp.asarray(to_numpy(inputs))) + + assert_allclose(self, actual, expected, atol=2e-7, rtol=1e-6) + + def test_motion_encoder_parity(self): + cfg = { + "size": 4, + "style_dim": 8, + "motion_dim": 4, + "out_dim": 8, + "motion_blocks": 3, + "channels": {"4": 8, "8": 8, "16": 8}, + } + hf_module = HFWanAnimateMotionEncoder(**cfg).eval() + max_module = WanAnimateMotionEncoder(rngs=self.rngs, **cfg) + copy_motion_encoder_params(max_module, hf_module) + + inputs = torch.randn(3, 3, 4, 4) + expected = hf_module(inputs) + actual = max_module(jnp.asarray(to_numpy(inputs))) + + assert_allclose(self, actual, expected, atol=5e-7, rtol=1e-6) + + def test_face_encoder_parity(self): + hf_module = HFWanAnimateFaceEncoder(in_dim=8, out_dim=12, hidden_dim=16, num_heads=2).eval() + max_module = WanAnimateFaceEncoder(rngs=self.rngs, in_dim=8, out_dim=12, hidden_dim=16, num_heads=2) + copy_face_encoder_params(max_module, hf_module) + + inputs = torch.randn(2, 7, 8) + expected = hf_module(inputs) + actual = max_module(jnp.asarray(to_numpy(inputs))) + + assert_allclose(self, actual, expected, atol=5e-7, rtol=1e-6) + + def test_face_block_cross_attention_parity(self): + hf_module = HFWanAnimateFaceBlockCrossAttention(dim=12, heads=3, dim_head=4, cross_attention_dim_head=4).eval() + max_module = WanAnimateFaceBlockCrossAttention( + rngs=self.rngs, + dim=12, + heads=3, + dim_head=4, + cross_attention_dim_head=4, + ) + copy_face_block_cross_attention_params(max_module, hf_module) + + hidden_states = torch.randn(2, 8, 12) + encoder_hidden_states = torch.randn(2, 2, 3, 12) + expected = hf_module(hidden_states, encoder_hidden_states) + 
actual = max_module(jnp.asarray(to_numpy(hidden_states)), jnp.asarray(to_numpy(encoder_hidden_states))) + + assert_allclose(self, actual, expected, atol=2e-7, rtol=1e-6) + + def test_wan_animate_transformer_weight_mapping_covers_all_local_params(self): + cfg = { + "patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 4, + "in_channels": 12, + "latent_channels": 4, + "out_channels": 4, + "text_dim": 8, + "freq_dim": 8, + "ffn_dim": 16, + "num_layers": 1, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "eps": 1e-6, + "image_dim": 4, + "added_kv_proj_dim": None, + "rope_max_seq_len": 32, + "motion_encoder_channel_sizes": {"4": 8, "8": 8, "16": 8}, + "motion_encoder_size": 4, + "motion_style_dim": 8, + "motion_dim": 4, + "motion_encoder_dim": 8, + "face_encoder_hidden_dim": 8, + "face_encoder_num_heads": 2, + "inject_face_latents_blocks": 1, + "motion_encoder_batch_size": 2, + } + hf_model = HFWanAnimateTransformer3DModel(**cfg).eval() + + with self.mesh, nn_partitioning.axis_rules(self.logical_axis_rules): + max_model = NNXWanAnimateTransformer3DModel(rngs=self.rngs, scan_layers=False, mesh=self.mesh, **cfg) + missing_keys, flax_state_dict = map_hf_wan_animate_state_to_local( + max_model, hf_model, num_layers=cfg["num_layers"], scan_layers=False + ) + + self.assertFalse(missing_keys, msg=f"Unmapped animate parameters: {missing_keys}") + self.assertGreater(len(flax_state_dict), 0) + + def test_wan_animate_transformer_weight_mapping_covers_all_local_params_scanned(self): + cfg = { + "patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 4, + "in_channels": 12, + "latent_channels": 4, + "out_channels": 4, + "text_dim": 8, + "freq_dim": 8, + "ffn_dim": 16, + "num_layers": 1, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "eps": 1e-6, + "image_dim": 4, + "added_kv_proj_dim": None, + "rope_max_seq_len": 32, + "motion_encoder_channel_sizes": {"4": 8, "8": 8, "16": 8}, + "motion_encoder_size": 
4, + "motion_style_dim": 8, + "motion_dim": 4, + "motion_encoder_dim": 8, + "face_encoder_hidden_dim": 8, + "face_encoder_num_heads": 2, + "inject_face_latents_blocks": 1, + "motion_encoder_batch_size": 2, + } + hf_model = HFWanAnimateTransformer3DModel(**cfg).eval() + + with self.mesh, nn_partitioning.axis_rules(self.logical_axis_rules): + max_model = NNXWanAnimateTransformer3DModel(rngs=self.rngs, scan_layers=True, mesh=self.mesh, **cfg) + missing_keys, flax_state_dict = map_hf_wan_animate_state_to_local( + max_model, hf_model, num_layers=cfg["num_layers"], scan_layers=True + ) + + self.assertFalse(missing_keys, msg=f"Unmapped animate parameters for scanned model: {missing_keys}") + self.assertGreater(len(flax_state_dict), 0) + + def test_wan_animate_transformer_block_mapping_supports_scan_layers_toggle(self): + tensor = jnp.arange(12, dtype=jnp.float32).reshape(3, 4) + pt_tuple_key = ("blocks", "1", "attn1", "to_q", "weight") + + unscanned_shapes = { + ("blocks", "1", "attn1", "query", "kernel"): jnp.zeros((4, 3), dtype=jnp.float32), + } + flax_key, flax_tensor = get_wan_animate_key_and_value( + pt_tuple_key, + tensor, + {}, + unscanned_shapes, + False, + num_layers=2, + ) + self.assertEqual(flax_key, ("blocks", 1, "attn1", "query", "kernel")) + np.testing.assert_array_equal(np.asarray(flax_tensor), np.asarray(tensor.T)) + + scanned_shapes = { + ("blocks", "attn1", "query", "kernel"): jnp.zeros((2, 4, 3), dtype=jnp.float32), + } + flax_key, flax_tensor = get_wan_animate_key_and_value( + pt_tuple_key, + tensor, + {}, + scanned_shapes, + True, + num_layers=2, + ) + self.assertEqual(flax_key, ("blocks", "attn1", "query", "kernel")) + expected = np.zeros((2, 4, 3), dtype=np.float32) + expected[1] = np.asarray(tensor.T) + np.testing.assert_array_equal(np.asarray(flax_tensor), expected) + + def test_wan_animate_transformer_forward_parity(self): + cfg = { + "patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 4, + "in_channels": 12, + 
"latent_channels": 4, + "out_channels": 4, + "text_dim": 8, + "freq_dim": 8, + "ffn_dim": 16, + "num_layers": 1, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "eps": 1e-6, + "image_dim": 4, + "added_kv_proj_dim": None, + "rope_max_seq_len": 32, + "motion_encoder_channel_sizes": {"4": 8, "8": 8, "16": 8}, + "motion_encoder_size": 4, + "motion_style_dim": 8, + "motion_dim": 4, + "motion_encoder_dim": 8, + "face_encoder_hidden_dim": 8, + "face_encoder_num_heads": 2, + "inject_face_latents_blocks": 1, + "motion_encoder_batch_size": 2, + } + hf_model = HFWanAnimateTransformer3DModel(**cfg).eval() + + with self.mesh, nn_partitioning.axis_rules(self.logical_axis_rules): + max_model = NNXWanAnimateTransformer3DModel(rngs=self.rngs, scan_layers=False, mesh=self.mesh, **cfg) + missing_keys, _ = map_hf_wan_animate_state_to_local( + max_model, hf_model, num_layers=cfg["num_layers"], scan_layers=False + ) + self.assertFalse(missing_keys, msg=f"Unmapped animate parameters: {missing_keys}") + + hidden_states = torch.randn(1, 12, 3, 4, 4) + pose_hidden_states = torch.randn(1, 4, 2, 4, 4) + encoder_hidden_states = torch.randn(1, 5, 8) + encoder_hidden_states_image = torch.randn(1, 3, 4) + face_pixel_values = torch.randn(1, 3, 2, 4, 4) + timestep = torch.tensor([7], dtype=torch.long) + + with torch.no_grad(): + expected = hf_model( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_image=encoder_hidden_states_image, + pose_hidden_states=pose_hidden_states, + face_pixel_values=face_pixel_values, + ).sample + + actual = max_model( + hidden_states=jnp.asarray(to_numpy(hidden_states)), + timestep=jnp.asarray(to_numpy(timestep)), + encoder_hidden_states=jnp.asarray(to_numpy(encoder_hidden_states)), + encoder_hidden_states_image=jnp.asarray(to_numpy(encoder_hidden_states_image)), + pose_hidden_states=jnp.asarray(to_numpy(pose_hidden_states)), + 
face_pixel_values=jnp.asarray(to_numpy(face_pixel_values)), + )["sample"] + + assert_allclose(self, actual, expected, atol=5e-5, rtol=1e-5) + + def test_wan_animate_transformer_forward_parity_scanned(self): + cfg = { + "patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 4, + "in_channels": 12, + "latent_channels": 4, + "out_channels": 4, + "text_dim": 8, + "freq_dim": 8, + "ffn_dim": 16, + "num_layers": 1, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "eps": 1e-6, + "image_dim": 4, + "added_kv_proj_dim": None, + "rope_max_seq_len": 32, + "motion_encoder_channel_sizes": {"4": 8, "8": 8, "16": 8}, + "motion_encoder_size": 4, + "motion_style_dim": 8, + "motion_dim": 4, + "motion_encoder_dim": 8, + "face_encoder_hidden_dim": 8, + "face_encoder_num_heads": 2, + "inject_face_latents_blocks": 1, + "motion_encoder_batch_size": 2, + } + hf_model = HFWanAnimateTransformer3DModel(**cfg).eval() + + with self.mesh, nn_partitioning.axis_rules(self.logical_axis_rules): + max_model = NNXWanAnimateTransformer3DModel(rngs=self.rngs, scan_layers=True, mesh=self.mesh, **cfg) + missing_keys, _ = map_hf_wan_animate_state_to_local( + max_model, hf_model, num_layers=cfg["num_layers"], scan_layers=True + ) + self.assertFalse(missing_keys, msg=f"Unmapped animate parameters for scanned model: {missing_keys}") + + hidden_states = torch.randn(1, 12, 3, 4, 4) + pose_hidden_states = torch.randn(1, 4, 2, 4, 4) + encoder_hidden_states = torch.randn(1, 5, 8) + encoder_hidden_states_image = torch.randn(1, 3, 4) + face_pixel_values = torch.randn(1, 3, 2, 4, 4) + timestep = torch.tensor([7], dtype=torch.long) + + with torch.no_grad(): + expected = hf_model( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_image=encoder_hidden_states_image, + pose_hidden_states=pose_hidden_states, + face_pixel_values=face_pixel_values, + ).sample + + actual = max_model( + 
hidden_states=jnp.asarray(to_numpy(hidden_states)), + timestep=jnp.asarray(to_numpy(timestep)), + encoder_hidden_states=jnp.asarray(to_numpy(encoder_hidden_states)), + encoder_hidden_states_image=jnp.asarray(to_numpy(encoder_hidden_states_image)), + pose_hidden_states=jnp.asarray(to_numpy(pose_hidden_states)), + face_pixel_values=jnp.asarray(to_numpy(face_pixel_values)), + )["sample"] + + assert_allclose(self, actual, expected, atol=5e-5, rtol=1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/maxdiffusion/tests/wan_common_module_parity_test.py b/src/maxdiffusion/tests/wan_common_module_parity_test.py new file mode 100644 index 00000000..368cf73d --- /dev/null +++ b/src/maxdiffusion/tests/wan_common_module_parity_test.py @@ -0,0 +1,424 @@ +""" +Copyright 2026 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import os +import unittest +from importlib import resources + +os.environ["JAX_PLATFORMS"] = "cpu" +os.environ.setdefault("MPLCONFIGDIR", "/tmp/mplconfig") + +import jax +import jax.numpy as jnp +import numpy as np +import pytest +import torch +from flax import nnx +from flax.linen import partitioning as nn_partitioning +from flax.traverse_util import flatten_dict +from jax.sharding import Mesh + +jax.config.update("jax_platforms", "cpu") + +from diffusers.models.attention import FeedForward as HFFeedForward +from diffusers.models.transformers.transformer_wan import ( + WanImageEmbedding as HFWanImageEmbedding, + WanRotaryPosEmbed as HFWanRotaryPosEmbed, + WanTimeTextImageEmbedding as HFWanTimeTextImageEmbedding, + WanTransformer3DModel as HFWanTransformer3DModel, + WanTransformerBlock as HFWanTransformerBlock, +) + +from maxdiffusion import pyconfig +from maxdiffusion.max_utils import create_device_mesh +from maxdiffusion.models.embeddings_flax import NNXWanImageEmbedding +from maxdiffusion.models.modeling_flax_pytorch_utils import rename_key +from maxdiffusion.models.wan.transformers.transformer_wan import ( + WanFeedForward, + WanModel, + WanRotaryPosEmbed, + WanTimeTextImageEmbedding, + WanTransformerBlock, +) +from maxdiffusion.models.wan.wan_utils import get_key_and_value + +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" + + +def to_numpy(array): + if isinstance(array, torch.Tensor): + if array.dtype == torch.bfloat16: + array = array.float() + return array.detach().cpu().numpy() + return np.asarray(array) + + +def assert_allclose(test_case, actual, expected, *, atol=1e-6, rtol=1e-6): + test_case.assertEqual(to_numpy(actual).shape, to_numpy(expected).shape) + np.testing.assert_allclose(to_numpy(actual), to_numpy(expected), atol=atol, rtol=rtol) + + +def copy_linear_params(max_module, hf_module): + max_module.kernel[...] 
= jnp.asarray(to_numpy(hf_module.weight).T) + if getattr(max_module, "bias", None) is not None and hf_module.bias is not None: + max_module.bias[...] = jnp.asarray(to_numpy(hf_module.bias)) + + +def copy_fp32_layer_norm_params(max_module, hf_module): + max_module.layer_norm.scale[...] = jnp.asarray(to_numpy(hf_module.weight)) + max_module.layer_norm.bias[...] = jnp.asarray(to_numpy(hf_module.bias)) + + +def copy_wan_image_embedding_params(max_module, hf_module): + copy_fp32_layer_norm_params(max_module.norm1, hf_module.norm1) + max_module.ff.net_0.kernel[...] = jnp.asarray(to_numpy(hf_module.ff.net[0].proj.weight).T) + max_module.ff.net_0.bias[...] = jnp.asarray(to_numpy(hf_module.ff.net[0].proj.bias)) + max_module.ff.net_2.kernel[...] = jnp.asarray(to_numpy(hf_module.ff.net[2].weight).T) + max_module.ff.net_2.bias[...] = jnp.asarray(to_numpy(hf_module.ff.net[2].bias)) + copy_fp32_layer_norm_params(max_module.norm2, hf_module.norm2) + if max_module.pos_embed is not None and hf_module.pos_embed is not None: + max_module.pos_embed[...] 
= jnp.asarray(to_numpy(hf_module.pos_embed)) + + +def copy_wan_time_text_image_embedding_params(max_module, hf_module): + copy_linear_params(max_module.time_embedder.linear_1, hf_module.time_embedder.linear_1) + copy_linear_params(max_module.time_embedder.linear_2, hf_module.time_embedder.linear_2) + copy_linear_params(max_module.time_proj, hf_module.time_proj) + copy_linear_params(max_module.text_embedder.linear_1, hf_module.text_embedder.linear_1) + copy_linear_params(max_module.text_embedder.linear_2, hf_module.text_embedder.linear_2) + if max_module.image_embedder is not None and hf_module.image_embedder is not None: + copy_wan_image_embedding_params(max_module.image_embedder, hf_module.image_embedder) + + +def copy_wan_feed_forward_params(max_module, hf_module): + copy_linear_params(max_module.act_fn.proj, hf_module.net[0].proj) + copy_linear_params(max_module.proj_out, hf_module.net[2]) + + +def copy_wan_attention_params(max_module, hf_module): + copy_linear_params(max_module.query, hf_module.to_q) + copy_linear_params(max_module.key, hf_module.to_k) + copy_linear_params(max_module.value, hf_module.to_v) + copy_linear_params(max_module.proj_attn, hf_module.to_out[0]) + max_module.norm_q.scale[...] = jnp.asarray(to_numpy(hf_module.norm_q.weight)) + max_module.norm_k.scale[...] = jnp.asarray(to_numpy(hf_module.norm_k.weight)) + + +def copy_wan_transformer_block_params(max_module, hf_module): + max_module.adaln_scale_shift_table[...] 
= jnp.asarray(to_numpy(hf_module.scale_shift_table)) + copy_wan_attention_params(max_module.attn1, hf_module.attn1) + copy_wan_attention_params(max_module.attn2, hf_module.attn2) + copy_fp32_layer_norm_params(max_module.norm2, hf_module.norm2) + copy_wan_feed_forward_params(max_module.ffn, hf_module.ffn) + + +def map_hf_wan_state_to_local(max_model, hf_model, num_layers): + state = nnx.state(max_model) + flat_vars = dict(nnx.to_flat_state(state)) + random_flax_state_dict = { + tuple(str(item) for item in key): value for key, value in flatten_dict(state.to_pure_dict()).items() + } + flax_state_dict = {} + + for pt_key, tensor in hf_model.state_dict().items(): + if "norm_added_q" in pt_key: + continue + + renamed_pt_key = rename_key(pt_key) + + if "condition_embedder" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("time_embedding_0", "time_embedder.linear_1") + renamed_pt_key = renamed_pt_key.replace("time_embedding_2", "time_embedder.linear_2") + renamed_pt_key = renamed_pt_key.replace("time_projection_1", "time_proj") + renamed_pt_key = renamed_pt_key.replace("text_embedding_0", "text_embedder.linear_1") + renamed_pt_key = renamed_pt_key.replace("text_embedding_2", "text_embedder.linear_2") + + if "image_embedder" in renamed_pt_key: + if "net.0.proj" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("net.0.proj", "net_0") + elif "net_0.proj" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("net_0.proj", "net_0") + if "net.2" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("net.2", "net_2") + renamed_pt_key = renamed_pt_key.replace("norm1", "norm1.layer_norm") + if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key: + renamed_pt_key = renamed_pt_key.replace("weight", "scale") + renamed_pt_key = renamed_pt_key.replace("kernel", "scale") + + renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.") + renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table") + renamed_pt_key = 
renamed_pt_key.replace("to_out_0", "proj_attn") + renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out") + renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn") + renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm") + + pt_tuple_key = tuple(renamed_pt_key.split(".")) + flax_key, flax_tensor = get_key_and_value( + pt_tuple_key, + to_numpy(tensor), + flax_state_dict, + random_flax_state_dict, + False, + num_layers, + ) + flax_state_dict[flax_key] = jnp.asarray(flax_tensor) + + missing_keys = [key for key in flax_state_dict if key not in flat_vars] + for key, value in flax_state_dict.items(): + if key in flat_vars: + flat_vars[key][...] = value + + return missing_keys, flax_state_dict + + +@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run WAN parity tests on Github Actions") +class WanCommonModuleParityTest(unittest.TestCase): + + @classmethod + def setUpClass(cls): + with resources.as_file(resources.files("maxdiffusion.configs").joinpath("base_wan_14b.yml")) as config_path: + pyconfig.initialize([None, os.fspath(config_path)], unittest=True) + config = pyconfig.config + cls.logical_axis_rules = config.logical_axis_rules + cls.mesh = Mesh(create_device_mesh(config), config.mesh_axes) + + def setUp(self): + torch.manual_seed(0) + self.rngs = nnx.Rngs(jax.random.key(0)) + + def test_wan_rotary_pos_embed_parity(self): + hf_module = HFWanRotaryPosEmbed(attention_head_dim=12, patch_size=(1, 2, 2), max_seq_len=32) + max_module = WanRotaryPosEmbed(attention_head_dim=12, patch_size=(1, 2, 2), max_seq_len=32) + + hidden_states = torch.randn(1, 12, 3, 4, 4) + freqs_cos, freqs_sin = hf_module(hidden_states) + hf_complex = to_numpy(freqs_cos)[..., 0::2] + 1j * to_numpy(freqs_sin)[..., 1::2] + expected = np.transpose(hf_complex, (0, 2, 1, 3)) + + actual = max_module(jnp.asarray(np.transpose(to_numpy(hidden_states), (0, 2, 3, 4, 1)))) + + assert_allclose(self, actual, expected, atol=0.0, rtol=0.0) + + def 
test_wan_image_embedding_parity(self): + hf_module = HFWanImageEmbedding(4, 8, pos_embed_seq_len=None).eval() + max_module = NNXWanImageEmbedding( + rngs=self.rngs, + in_features=4, + out_features=8, + pos_embed_seq_len=None, + dtype=jnp.float32, + weights_dtype=jnp.float32, + precision=None, + flash_min_seq_length=4096, + ) + copy_wan_image_embedding_params(max_module, hf_module) + + encoder_hidden_states_image = torch.randn(2, 3, 4) + expected = hf_module(encoder_hidden_states_image) + actual, attention_mask = max_module(jnp.asarray(to_numpy(encoder_hidden_states_image))) + + self.assertIsNone(attention_mask) + assert_allclose(self, actual, expected, atol=3e-4, rtol=3e-4) + + def test_wan_time_text_image_embedding_parity(self): + hf_module = HFWanTimeTextImageEmbedding( + dim=8, + time_freq_dim=8, + time_proj_dim=48, + text_embed_dim=6, + image_embed_dim=4, + pos_embed_seq_len=None, + ).eval() + max_module = WanTimeTextImageEmbedding( + rngs=self.rngs, + dim=8, + time_freq_dim=8, + time_proj_dim=48, + text_embed_dim=6, + image_embed_dim=4, + pos_embed_seq_len=None, + flash_min_seq_length=4096, + ) + copy_wan_time_text_image_embedding_params(max_module, hf_module) + + timestep = torch.tensor([3, 7], dtype=torch.long) + encoder_hidden_states = torch.randn(2, 5, 6) + encoder_hidden_states_image = torch.randn(2, 3, 4) + + expected_temb, expected_tproj, expected_text, expected_image = hf_module( + timestep, encoder_hidden_states, encoder_hidden_states_image + ) + actual_temb, actual_tproj, actual_text, actual_image, actual_mask = max_module( + jnp.asarray(to_numpy(timestep)), + jnp.asarray(to_numpy(encoder_hidden_states)), + jnp.asarray(to_numpy(encoder_hidden_states_image)), + ) + + self.assertIsNone(actual_mask) + assert_allclose(self, actual_temb, expected_temb, atol=1e-7, rtol=1e-7) + assert_allclose(self, actual_tproj, expected_tproj, atol=1e-7, rtol=1e-7) + assert_allclose(self, actual_text, expected_text, atol=2e-7, rtol=1e-6) + assert_allclose(self, 
actual_image, expected_image, atol=2e-4, rtol=3e-4) + + def test_wan_feed_forward_parity(self): + hf_module = HFFeedForward(8, inner_dim=16, activation_fn="gelu-approximate").eval() + max_module = WanFeedForward(rngs=self.rngs, dim=8, inner_dim=16, activation_fn="gelu-approximate") + copy_wan_feed_forward_params(max_module, hf_module) + + hidden_states = torch.randn(2, 5, 8) + expected = hf_module(hidden_states) + actual = max_module(jnp.asarray(to_numpy(hidden_states))) + + assert_allclose(self, actual, expected, atol=1e-7, rtol=1e-6) + + def test_wan_transformer_block_parity(self): + hf_module = HFWanTransformerBlock( + dim=8, + ffn_dim=16, + num_heads=2, + qk_norm="rms_norm_across_heads", + cross_attn_norm=True, + eps=1e-6, + ).eval() + + with self.mesh, nn_partitioning.axis_rules(self.logical_axis_rules): + max_module = WanTransformerBlock( + rngs=self.rngs, + dim=8, + ffn_dim=16, + num_heads=2, + qk_norm="rms_norm_across_heads", + cross_attn_norm=True, + eps=1e-6, + attention="dot_product", + flash_min_seq_length=4096, + mesh=self.mesh, + ) + copy_wan_transformer_block_params(max_module, hf_module) + + rope_hf = HFWanRotaryPosEmbed(attention_head_dim=4, patch_size=(1, 2, 2), max_seq_len=32) + rope_max = WanRotaryPosEmbed(attention_head_dim=4, patch_size=(1, 2, 2), max_seq_len=32) + hidden_5d = torch.randn(1, 8, 3, 4, 4) + freqs_cos, freqs_sin = rope_hf(hidden_5d) + rotary_emb = rope_max(jnp.asarray(np.transpose(to_numpy(hidden_5d), (0, 2, 3, 4, 1)))) + + hidden_states = torch.randn(1, 12, 8) + encoder_hidden_states = torch.randn(1, 5, 8) + temb = torch.randn(1, 6, 8) + + expected = hf_module(hidden_states, encoder_hidden_states, temb, (freqs_cos, freqs_sin)) + actual = max_module( + jnp.asarray(to_numpy(hidden_states)), + jnp.asarray(to_numpy(encoder_hidden_states)), + jnp.asarray(to_numpy(temb)), + rotary_emb, + ) + + assert_allclose(self, actual, expected, atol=3e-7, rtol=1e-6) + + def test_wan_model_weight_mapping_covers_all_local_params(self): + cfg = { + 
"patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 4, + "in_channels": 4, + "out_channels": 4, + "text_dim": 8, + "freq_dim": 8, + "ffn_dim": 16, + "num_layers": 1, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "eps": 1e-6, + "image_dim": 4, + "added_kv_proj_dim": None, + "rope_max_seq_len": 32, + "pos_embed_seq_len": None, + } + hf_model = HFWanTransformer3DModel(**cfg).eval() + + with self.mesh, nn_partitioning.axis_rules(self.logical_axis_rules): + max_model = WanModel( + rngs=self.rngs, + scan_layers=False, + mesh=self.mesh, + attention="dot_product", + flash_min_seq_length=4096, + **cfg, + ) + missing_keys, flax_state_dict = map_hf_wan_state_to_local(max_model, hf_model, num_layers=cfg["num_layers"]) + + self.assertFalse(missing_keys, msg=f"Unmapped WAN parameters: {missing_keys}") + self.assertGreater(len(flax_state_dict), 0) + + def test_wan_model_forward_parity(self): + cfg = { + "patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 4, + "in_channels": 4, + "out_channels": 4, + "text_dim": 8, + "freq_dim": 8, + "ffn_dim": 16, + "num_layers": 1, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "eps": 1e-6, + "image_dim": 4, + "added_kv_proj_dim": None, + "rope_max_seq_len": 32, + "pos_embed_seq_len": None, + } + hf_model = HFWanTransformer3DModel(**cfg).eval() + + with self.mesh, nn_partitioning.axis_rules(self.logical_axis_rules): + max_model = WanModel( + rngs=self.rngs, + scan_layers=False, + mesh=self.mesh, + attention="dot_product", + flash_min_seq_length=4096, + **cfg, + ) + missing_keys, _ = map_hf_wan_state_to_local(max_model, hf_model, num_layers=cfg["num_layers"]) + self.assertFalse(missing_keys, msg=f"Unmapped WAN parameters: {missing_keys}") + + hidden_states = torch.randn(1, 4, 3, 4, 4) + timestep = torch.tensor([7], dtype=torch.long) + encoder_hidden_states = torch.randn(1, 5, 8) + encoder_hidden_states_image = torch.randn(1, 3, 4) + + with torch.no_grad(): 
+ expected = hf_model( + hidden_states=hidden_states, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_image=encoder_hidden_states_image, + ).sample + + actual = max_model( + hidden_states=jnp.asarray(to_numpy(hidden_states)), + timestep=jnp.asarray(to_numpy(timestep)), + encoder_hidden_states=jnp.asarray(to_numpy(encoder_hidden_states)), + encoder_hidden_states_image=jnp.asarray(to_numpy(encoder_hidden_states_image)), + ) + + assert_allclose(self, actual, expected, atol=2e-5, rtol=1e-5) + + +if __name__ == "__main__": + unittest.main()