From f07f1e85cdc2432f02024540243aa9d5363cdb04 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Wed, 1 Apr 2026 23:15:44 -0700 Subject: [PATCH 01/14] Add modular pipeline support for HunyuanVideo 1.5 --- src/diffusers/__init__.py | 6 + src/diffusers/modular_pipelines/__init__.py | 6 + .../hunyuan_video1_5/__init__.py | 47 +++ .../hunyuan_video1_5/before_denoise.py | 324 +++++++++++++++++ .../hunyuan_video1_5/decoders.py | 80 ++++ .../hunyuan_video1_5/denoise.py | 341 ++++++++++++++++++ .../hunyuan_video1_5/encoders.py | 284 +++++++++++++++ .../modular_blocks_hunyuan_video1_5.py | 107 ++++++ .../hunyuan_video1_5/modular_pipeline.py | 114 ++++++ .../modular_pipelines/modular_pipeline.py | 1 + .../hunyuan_video1_5/__init__.py | 0 .../test_modular_pipeline_hunyuan_video1_5.py | 99 +++++ 12 files changed, 1409 insertions(+) create mode 100644 src/diffusers/modular_pipelines/hunyuan_video1_5/__init__.py create mode 100644 src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py create mode 100644 src/diffusers/modular_pipelines/hunyuan_video1_5/decoders.py create mode 100644 src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py create mode 100644 src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py create mode 100644 src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py create mode 100644 src/diffusers/modular_pipelines/hunyuan_video1_5/modular_pipeline.py create mode 100644 tests/modular_pipelines/hunyuan_video1_5/__init__.py create mode 100644 tests/modular_pipelines/hunyuan_video1_5/test_modular_pipeline_hunyuan_video1_5.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0f74c0bbcb4a..6d79e9733381 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -452,6 +452,9 @@ "HeliosPyramidDistilledAutoBlocks", "HeliosPyramidDistilledModularPipeline", "HeliosPyramidModularPipeline", + "HunyuanVideo15Blocks", + "HunyuanVideo15Image2VideoBlocks", + 
"HunyuanVideo15ModularPipeline", "QwenImageAutoBlocks", "QwenImageEditAutoBlocks", "QwenImageEditModularPipeline", @@ -1227,6 +1230,9 @@ HeliosPyramidDistilledAutoBlocks, HeliosPyramidDistilledModularPipeline, HeliosPyramidModularPipeline, + HunyuanVideo15Blocks, + HunyuanVideo15Image2VideoBlocks, + HunyuanVideo15ModularPipeline, QwenImageAutoBlocks, QwenImageEditAutoBlocks, QwenImageEditModularPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index fd9bd691ca87..ae8cb9762f21 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -88,6 +88,11 @@ "QwenImageLayeredModularPipeline", "QwenImageLayeredAutoBlocks", ] + _import_structure["hunyuan_video1_5"] = [ + "HunyuanVideo15Blocks", + "HunyuanVideo15Image2VideoBlocks", + "HunyuanVideo15ModularPipeline", + ] _import_structure["z_image"] = [ "ZImageAutoBlocks", "ZImageModularPipeline", @@ -140,6 +145,7 @@ QwenImageLayeredModularPipeline, QwenImageModularPipeline, ) + from .hunyuan_video1_5 import HunyuanVideo15Blocks, HunyuanVideo15Image2VideoBlocks, HunyuanVideo15ModularPipeline from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline from .wan import ( Wan22Blocks, diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/__init__.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/__init__.py new file mode 100644 index 000000000000..73de8277d004 --- /dev/null +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from 
...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_hunyuan_video1_5"] = ["HunyuanVideo15Blocks", "HunyuanVideo15Image2VideoBlocks"] + _import_structure["modular_pipeline"] = ["HunyuanVideo15ModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_hunyuan_video1_5 import HunyuanVideo15Blocks, HunyuanVideo15Image2VideoBlocks + from .modular_pipeline import HunyuanVideo15ModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py new file mode 100644 index 000000000000..036f320931e1 --- /dev/null +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py @@ -0,0 +1,324 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect + +import numpy as np +import torch + +from ...models import HunyuanVideo15Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import HunyuanVideo15ModularPipeline + + +logger = logging.get_logger(__name__) + + +def retrieve_timesteps( + scheduler, + num_inference_steps=None, + device=None, + timesteps=None, + sigmas=None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed.") + if timesteps is not None: + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom sigmas." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class HunyuanVideo15TextInputStep(ModularPipelineBlocks): + model_name = "hunyuan-video-1.5" + + @property + def description(self) -> str: + return "Input processing step that determines batch_size and dtype" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("transformer", HunyuanVideo15Transformer3DModel), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_videos_per_prompt", default=1), + InputParam("prompt_embeds", required=True, type_hint=torch.Tensor), + InputParam("batch_size", type_hint=int), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("batch_size", type_hint=int), + OutputParam("dtype", type_hint=torch.dtype), + ] + + @torch.no_grad() + def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.batch_size = getattr(block_state, "batch_size", None) or block_state.prompt_embeds.shape[0] + block_state.dtype = components.transformer.dtype + + self.set_block_state(state, block_state) + return components, state + + +class HunyuanVideo15SetTimestepsStep(ModularPipelineBlocks): + model_name = "hunyuan-video-1.5" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("sigmas"), + ] + + @property + def 
intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor), + OutputParam("num_inference_steps", type_hint=int), + ] + + @torch.no_grad() + def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + sigmas = block_state.sigmas + if sigmas is None: + sigmas = np.linspace(1.0, 0.0, block_state.num_inference_steps + 1)[:-1] + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + device, + sigmas=sigmas, + ) + + self.set_block_state(state, block_state) + return components, state + + +class HunyuanVideo15PrepareLatentsStep(ModularPipelineBlocks): + model_name = "hunyuan-video-1.5" + + @property + def description(self) -> str: + return "Prepare latents step for text-to-video generation" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("num_frames", type_hint=int, default=121), + InputParam("latents", type_hint=torch.Tensor | None), + InputParam("num_videos_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam("batch_size", required=True, type_hint=int), + InputParam("dtype", type_hint=torch.dtype), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("latents", type_hint=torch.Tensor), + OutputParam("cond_latents_concat", type_hint=torch.Tensor), + OutputParam("mask_concat", type_hint=torch.Tensor), + OutputParam("image_embeds", type_hint=torch.Tensor), + ] + + @torch.no_grad() + def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + dtype = block_state.dtype + + batch_size = block_state.batch_size * 
block_state.num_videos_per_prompt + height = block_state.height or components.default_height + width = block_state.width or components.default_width + num_frames = block_state.num_frames + + num_channels_latents = components.num_channels_latents + latent_height = height // components.vae_spatial_compression_ratio + latent_width = width // components.vae_spatial_compression_ratio + latent_num_frames = (num_frames - 1) // components.vae_temporal_compression_ratio + 1 + + if block_state.latents is not None: + block_state.latents = block_state.latents.to(device=device, dtype=dtype) + else: + shape = (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width) + block_state.latents = randn_tensor(shape, generator=block_state.generator, device=device, dtype=dtype) + + # T2V: zero cond_latents and mask + b, c, f, h, w = block_state.latents.shape + block_state.cond_latents_concat = torch.zeros(b, c, f, h, w, dtype=dtype, device=device) + block_state.mask_concat = torch.zeros(b, 1, f, h, w, dtype=dtype, device=device) + + # T2V: zero image_embeds + block_state.image_embeds = torch.zeros( + block_state.batch_size, + components.vision_num_semantic_tokens, + components.vision_states_dim, + dtype=dtype, + device=device, + ) + + self.set_block_state(state, block_state) + return components, state + + +def retrieve_latents(encoder_output, generator=None, sample_mode="sample"): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + raise AttributeError("Could not access latents of provided encoder_output") + + +class HunyuanVideo15Image2VideoPrepareLatentsStep(ModularPipelineBlocks): + model_name = "hunyuan-video-1.5" + + @property + def description(self) -> str: + return "Prepare latents step for 
image-to-video: encodes the first frame and creates conditioning mask" + + @property + def expected_components(self) -> list[ComponentSpec]: + from ...models import AutoencoderKLHunyuanVideo15 + from transformers import SiglipVisionModel, SiglipImageProcessor + return [ + ComponentSpec("vae", AutoencoderKLHunyuanVideo15), + ComponentSpec("image_encoder", SiglipVisionModel), + ComponentSpec("feature_extractor", SiglipImageProcessor), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("image", required=True), + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("num_frames", type_hint=int, default=121), + InputParam("latents", type_hint=torch.Tensor | None), + InputParam("num_videos_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam("batch_size", required=True, type_hint=int), + InputParam("dtype", type_hint=torch.dtype), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("latents", type_hint=torch.Tensor), + OutputParam("cond_latents_concat", type_hint=torch.Tensor), + OutputParam("mask_concat", type_hint=torch.Tensor), + OutputParam("image_embeds", type_hint=torch.Tensor), + ] + + @torch.no_grad() + def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + dtype = block_state.dtype + + batch_size = block_state.batch_size * block_state.num_videos_per_prompt + height = block_state.height or components.default_height + width = block_state.width or components.default_width + num_frames = block_state.num_frames + + num_channels_latents = components.num_channels_latents + latent_height = height // components.vae_spatial_compression_ratio + latent_width = width // components.vae_spatial_compression_ratio + latent_num_frames = (num_frames - 1) // components.vae_temporal_compression_ratio + 1 + + if 
block_state.latents is not None: + block_state.latents = block_state.latents.to(device=device, dtype=dtype) + else: + shape = (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width) + block_state.latents = randn_tensor(shape, generator=block_state.generator, device=device, dtype=dtype) + + # Resize/crop image to target resolution (matching upstream flow) + image = block_state.image + from ...pipelines.hunyuan_video1_5.image_processor import HunyuanVideo15ImageProcessor + video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=components.vae_spatial_compression_ratio) + height, width = video_processor.calculate_default_height_width( + height=image.size[1], width=image.size[0], target_size=components.target_size + ) + image = video_processor.resize(image, height=height, width=width, resize_mode="crop") + + # Encode image for Siglip embeddings + image_encoder_dtype = next(components.image_encoder.parameters()).dtype + image_inputs = components.feature_extractor.preprocess( + images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True + ) + image_inputs = image_inputs.to(device=device, dtype=image_encoder_dtype) + image_embeds = components.image_encoder(**image_inputs).last_hidden_state + image_embeds = image_embeds.repeat(batch_size, 1, 1) + block_state.image_embeds = image_embeds.to(device=device, dtype=dtype) + + # Encode image for VAE conditioning latents + vae_dtype = components.vae.dtype + image_tensor = video_processor.preprocess(image, height=height, width=width).to(device, dtype=vae_dtype) + image_tensor = image_tensor.unsqueeze(2) + image_latents = retrieve_latents(components.vae.encode(image_tensor), sample_mode="argmax") + image_latents = image_latents * components.vae.config.scaling_factor + + b, c, f, h, w = block_state.latents.shape + latent_condition = image_latents.repeat(batch_size, 1, f, 1, 1) + latent_condition[:, :, 1:, :, :] = 0 + block_state.cond_latents_concat = latent_condition.to(device=device, 
dtype=dtype) + + latent_mask = torch.zeros(b, 1, f, h, w, dtype=dtype, device=device) + latent_mask[:, :, 0, :, :] = 1.0 + block_state.mask_concat = latent_mask + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/decoders.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/decoders.py new file mode 100644 index 000000000000..06f2aab6daa7 --- /dev/null +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/decoders.py @@ -0,0 +1,80 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKLHunyuanVideo15 +from ...utils import logging +from ...video_processor import VideoProcessor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) + + +class HunyuanVideo15VaeDecoderStep(ModularPipelineBlocks): + model_name = "hunyuan-video-1.5" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLHunyuanVideo15), + ComponentSpec( + "video_processor", + VideoProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into videos" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam("latents", required=True, type_hint=torch.Tensor), + InputParam("output_type", default="np", type_hint=str), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "videos", + type_hint=list[list[PIL.Image.Image]] | list[torch.Tensor] | list[np.ndarray], + description="The generated videos", + ) + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + if block_state.output_type == "latent": + block_state.videos = block_state.latents + else: + latents = block_state.latents.to(components.vae.dtype) / components.vae.config.scaling_factor + video = components.vae.decode(latents, return_dict=False)[0] + block_state.videos = components.video_processor.postprocess_video(video, output_type=block_state.output_type) + + self.set_block_state(state, block_state) + return components, state diff --git 
a/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py new file mode 100644 index 000000000000..bcb3f55b2047 --- /dev/null +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py @@ -0,0 +1,341 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import HunyuanVideo15Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam +from .modular_pipeline import HunyuanVideo15ModularPipeline + + +logger = logging.get_logger(__name__) + + +class HunyuanVideo15LoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "hunyuan-video-1.5" + + @property + def description(self) -> str: + return "Step within the denoising loop that prepares the latent input" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("latents", required=True, type_hint=torch.Tensor), + InputParam("cond_latents_concat", required=True, type_hint=torch.Tensor), + InputParam("mask_concat", required=True, type_hint=torch.Tensor), + InputParam("dtype", required=True, 
type_hint=torch.dtype), + ] + + @torch.no_grad() + def __call__(self, components: HunyuanVideo15ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.latent_model_input = torch.cat( + [block_state.latents, block_state.cond_latents_concat, block_state.mask_concat], dim=1 + ) + return components, block_state + + +class HunyuanVideo15LoopDenoiser(ModularPipelineBlocks): + model_name = "hunyuan-video-1.5" + + def __init__(self, guider_input_fields=None): + if guider_input_fields is None: + guider_input_fields = { + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + "encoder_attention_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"), + "encoder_hidden_states_2": ("prompt_embeds_2", "negative_prompt_embeds_2"), + "encoder_attention_mask_2": ("prompt_embeds_mask_2", "negative_prompt_embeds_mask_2"), + } + if not isinstance(guider_input_fields, dict): + raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 1.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", HunyuanVideo15Transformer3DModel), + ] + + @property + def description(self) -> str: + return "Step within the denoising loop that denoises the latents with guidance" + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam("attention_kwargs"), + InputParam("num_inference_steps", required=True, type_hint=int), + InputParam("image_embeds", type_hint=torch.Tensor), + ] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + inputs.append(InputParam(name=value[0], required=True, type_hint=torch.Tensor)) + for neg_name in value[1:]: + inputs.append(InputParam(name=neg_name, type_hint=torch.Tensor)) 
+ else: + inputs.append(InputParam(name=value, required=True, type_hint=torch.Tensor)) + return inputs + + @torch.no_grad() + def __call__( + self, components: HunyuanVideo15ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) + + timestep = t.expand(block_state.latent_model_input.shape[0]).to(block_state.latent_model_input.dtype) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = { + k: v for k, v in cond_kwargs.items() if k in self._guider_input_fields.keys() + } + + context_name = getattr(guider_state_batch, components.guider._identifier_key, None) + with components.transformer.cache_context(context_name): + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latent_model_input, + image_embeds=block_state.image_embeds, + timestep=timestep, + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + components.guider.cleanup_models(components.transformer) + + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +class HunyuanVideo15LoopAfterDenoiser(ModularPipelineBlocks): + model_name = "hunyuan-video-1.5" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step within the denoising loop that updates the latents" + + @torch.no_grad() + def __call__(self, components: HunyuanVideo15ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + 
block_state.noise_pred, + t, + block_state.latents, + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class HunyuanVideo15DenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "hunyuan-video-1.5" + + @property + def description(self) -> str: + return "Pipeline block that iteratively denoises the latents over timesteps" + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", HunyuanVideo15Transformer3DModel), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam("timesteps", required=True, type_hint=torch.Tensor), + InputParam("num_inference_steps", required=True, type_hint=int), + ] + + @torch.no_grad() + def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + return components, state + + +class HunyuanVideo15DenoiseStep(HunyuanVideo15DenoiseLoopWrapper): + block_classes = [ + HunyuanVideo15LoopBeforeDenoiser, + HunyuanVideo15LoopDenoiser(), + HunyuanVideo15LoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return 
( + "Denoise step that iteratively denoises the latents.\n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `HunyuanVideo15LoopBeforeDenoiser`\n" + " - `HunyuanVideo15LoopDenoiser`\n" + " - `HunyuanVideo15LoopAfterDenoiser`\n" + "This block supports text-to-video tasks." + ) + + +class HunyuanVideo15Image2VideoLoopDenoiser(ModularPipelineBlocks): + model_name = "hunyuan-video-1.5" + + def __init__(self, guider_input_fields=None): + if guider_input_fields is None: + guider_input_fields = { + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + "encoder_attention_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"), + "encoder_hidden_states_2": ("prompt_embeds_2", "negative_prompt_embeds_2"), + "encoder_attention_mask_2": ("prompt_embeds_mask_2", "negative_prompt_embeds_mask_2"), + } + if not isinstance(guider_input_fields, dict): + raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 1.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", HunyuanVideo15Transformer3DModel), + ] + + @property + def description(self) -> str: + return "I2V denoiser with MeanFlow timestep_r support" + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam("attention_kwargs"), + InputParam("num_inference_steps", required=True, type_hint=int), + InputParam("image_embeds", type_hint=torch.Tensor), + InputParam("timesteps", required=True, type_hint=torch.Tensor), + ] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + inputs.append(InputParam(name=value[0], required=True, type_hint=torch.Tensor)) + for neg_name in value[1:]: + 
inputs.append(InputParam(name=neg_name, type_hint=torch.Tensor)) + else: + inputs.append(InputParam(name=value, required=True, type_hint=torch.Tensor)) + return inputs + + @torch.no_grad() + def __call__( + self, components: HunyuanVideo15ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) + + timestep = t.expand(block_state.latent_model_input.shape[0]).to(block_state.latent_model_input.dtype) + + # MeanFlow: compute timestep_r + timestep_r = None + if components.transformer.config.use_meanflow: + if i == len(block_state.timesteps) - 1: + timestep_r = torch.tensor([0.0], device=timestep.device) + else: + timestep_r = block_state.timesteps[i + 1] + timestep_r = timestep_r.expand(block_state.latents.shape[0]).to(block_state.latents.dtype) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = { + k: v for k, v in cond_kwargs.items() if k in self._guider_input_fields.keys() + } + + context_name = getattr(guider_state_batch, components.guider._identifier_key, None) + with components.transformer.cache_context(context_name): + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latent_model_input, + image_embeds=block_state.image_embeds, + timestep=timestep, + timestep_r=timestep_r, + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + components.guider.cleanup_models(components.transformer) + + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +class HunyuanVideo15Image2VideoDenoiseStep(HunyuanVideo15DenoiseLoopWrapper): + block_classes = [ + HunyuanVideo15LoopBeforeDenoiser, + 
HunyuanVideo15Image2VideoLoopDenoiser(), + HunyuanVideo15LoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step for image-to-video with MeanFlow support.\n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `HunyuanVideo15LoopBeforeDenoiser`\n" + " - `HunyuanVideo15Image2VideoLoopDenoiser`\n" + " - `HunyuanVideo15LoopAfterDenoiser`" + ) diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py new file mode 100644 index 000000000000..1d81d445c89b --- /dev/null +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py @@ -0,0 +1,284 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from transformers import ByT5Tokenizer, Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import HunyuanVideo15ModularPipeline + +from ...pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5 import ( + format_text_input, + extract_glyph_texts, +) + + +logger = logging.get_logger(__name__) + + +class HunyuanVideo15TextEncoderStep(ModularPipelineBlocks): + model_name = "hunyuan-video-1.5" + + @property + def description(self) -> str: + return "Dual text encoder step using Qwen2.5-VL (MLLM) and ByT5 (glyph text)" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", Qwen2_5_VLTextModel), + ComponentSpec("tokenizer", Qwen2Tokenizer), + ComponentSpec("text_encoder_2", T5EncoderModel), + ComponentSpec("tokenizer_2", ByT5Tokenizer), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 1.0}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("prompt"), + InputParam("negative_prompt"), + InputParam("prompt_embeds", type_hint=torch.Tensor), + InputParam("prompt_embeds_mask", type_hint=torch.Tensor), + InputParam("negative_prompt_embeds", type_hint=torch.Tensor), + InputParam("negative_prompt_embeds_mask", type_hint=torch.Tensor), + InputParam("prompt_embeds_2", type_hint=torch.Tensor), + InputParam("prompt_embeds_mask_2", type_hint=torch.Tensor), + InputParam("negative_prompt_embeds_2", type_hint=torch.Tensor), + InputParam("negative_prompt_embeds_mask_2", type_hint=torch.Tensor), + InputParam("num_videos_per_prompt", type_hint=int, default=1), + ] + + @property + def 
intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"), + OutputParam("prompt_embeds_mask", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"), + OutputParam("negative_prompt_embeds_mask", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"), + OutputParam("prompt_embeds_2", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"), + OutputParam("prompt_embeds_mask_2", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"), + OutputParam("negative_prompt_embeds_2", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"), + OutputParam("negative_prompt_embeds_mask_2", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"), + ] + + @staticmethod + def _get_mllm_prompt_embeds( + text_encoder, + tokenizer, + prompt, + device, + tokenizer_max_length=1000, + num_hidden_layers_to_skip=2, + system_message="You are a helpful assistant. Describe the video by detailing the following aspects: " + "1. The main content and theme of the video. " + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. " + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. " + "4. background environment, light, style and atmosphere. " + "5. 
camera angles, movements, and transitions used in the video.", + crop_start=108, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = format_text_input(prompt, system_message) + + text_inputs = tokenizer.apply_chat_template( + prompt, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + padding="max_length", + max_length=tokenizer_max_length + crop_start, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + return prompt_embeds, prompt_attention_mask + + @staticmethod + def _get_byt5_prompt_embeds(tokenizer, text_encoder, prompt, device, tokenizer_max_length=256): + prompt = [prompt] if isinstance(prompt, str) else prompt + glyph_texts = [extract_glyph_texts(p) for p in prompt] + + prompt_embeds_list = [] + prompt_embeds_mask_list = [] + + for glyph_text in glyph_texts: + if glyph_text is None: + glyph_text_embeds = torch.zeros( + (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype + ) + glyph_text_embeds_mask = torch.zeros((1, tokenizer_max_length), device=device, dtype=torch.int64) + else: + txt_tokens = tokenizer( + glyph_text, + padding="max_length", + max_length=tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).to(device) + + glyph_text_embeds = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask.float(), + )[0].to(device=device) + glyph_text_embeds_mask = txt_tokens.attention_mask.to(device=device) + + 
prompt_embeds_list.append(glyph_text_embeds) + prompt_embeds_mask_list.append(glyph_text_embeds_mask) + + return torch.cat(prompt_embeds_list, dim=0), torch.cat(prompt_embeds_mask_list, dim=0) + + @staticmethod + def encode_prompt( + components, + prompt, + device=None, + dtype=None, + batch_size=1, + num_videos_per_prompt=1, + prompt_embeds=None, + prompt_embeds_mask=None, + prompt_embeds_2=None, + prompt_embeds_mask_2=None, + ): + device = device or components._execution_device + dtype = dtype or components.text_encoder.dtype + + if prompt is None: + prompt = [""] * batch_size + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = HunyuanVideo15TextEncoderStep._get_mllm_prompt_embeds( + tokenizer=components.tokenizer, + text_encoder=components.text_encoder, + prompt=prompt, + device=device, + tokenizer_max_length=components.tokenizer_max_length, + system_message=components.system_message, + crop_start=components.prompt_template_encode_start_idx, + ) + + if prompt_embeds_2 is None: + prompt_embeds_2, prompt_embeds_mask_2 = HunyuanVideo15TextEncoderStep._get_byt5_prompt_embeds( + tokenizer=components.tokenizer_2, + text_encoder=components.text_encoder_2, + prompt=prompt, + device=device, + tokenizer_max_length=components.tokenizer_2_max_length, + ) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1).view( + batch_size * num_videos_per_prompt, seq_len, -1 + ) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1).view( + batch_size * num_videos_per_prompt, seq_len + ) + + _, seq_len_2, _ = prompt_embeds_2.shape + prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1).view( + batch_size * num_videos_per_prompt, seq_len_2, -1 + ) + prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1).view( + batch_size * num_videos_per_prompt, seq_len_2 + ) + + prompt_embeds = 
prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device) + prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device) + prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device) + + return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 + + @torch.no_grad() + def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + dtype = components.transformer.dtype + + prompt = block_state.prompt + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = getattr(block_state, "prompt_embeds", torch.empty(1)).shape[0] + + ( + block_state.prompt_embeds, + block_state.prompt_embeds_mask, + block_state.prompt_embeds_2, + block_state.prompt_embeds_mask_2, + ) = self.encode_prompt( + components=components, + prompt=prompt, + device=device, + dtype=dtype, + batch_size=batch_size, + num_videos_per_prompt=getattr(block_state, "num_videos_per_prompt", 1), + prompt_embeds=getattr(block_state, "prompt_embeds", None), + prompt_embeds_mask=getattr(block_state, "prompt_embeds_mask", None), + prompt_embeds_2=getattr(block_state, "prompt_embeds_2", None), + prompt_embeds_mask_2=getattr(block_state, "prompt_embeds_mask_2", None), + ) + + if components.guider._enabled and components.guider.num_conditions > 1: + negative_prompt = block_state.negative_prompt + ( + block_state.negative_prompt_embeds, + block_state.negative_prompt_embeds_mask, + block_state.negative_prompt_embeds_2, + block_state.negative_prompt_embeds_mask_2, + ) = self.encode_prompt( + components=components, + prompt=negative_prompt, + device=device, + dtype=dtype, + batch_size=batch_size, + num_videos_per_prompt=getattr(block_state, "num_videos_per_prompt", 1), + 
prompt_embeds=getattr(block_state, "negative_prompt_embeds", None), + prompt_embeds_mask=getattr(block_state, "negative_prompt_embeds_mask", None), + prompt_embeds_2=getattr(block_state, "negative_prompt_embeds_2", None), + prompt_embeds_mask_2=getattr(block_state, "negative_prompt_embeds_mask_2", None), + ) + + state.set("batch_size", batch_size) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py new file mode 100644 index 000000000000..f56cd5fb6119 --- /dev/null +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py @@ -0,0 +1,107 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import OutputParam +from .before_denoise import ( + HunyuanVideo15Image2VideoPrepareLatentsStep, + HunyuanVideo15PrepareLatentsStep, + HunyuanVideo15SetTimestepsStep, + HunyuanVideo15TextInputStep, +) +from .decoders import HunyuanVideo15VaeDecoderStep +from .denoise import HunyuanVideo15DenoiseStep, HunyuanVideo15Image2VideoDenoiseStep +from .encoders import HunyuanVideo15TextEncoderStep + + +logger = logging.get_logger(__name__) + + +# auto_docstring +class HunyuanVideo15CoreDenoiseStep(SequentialPipelineBlocks): + model_name = "hunyuan-video-1.5" + block_classes = [ + HunyuanVideo15TextInputStep, + HunyuanVideo15SetTimestepsStep, + HunyuanVideo15PrepareLatentsStep, + HunyuanVideo15DenoiseStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] + + @property + def description(self): + return "Denoise block that takes encoded conditions and runs the denoising process." + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# auto_docstring +class HunyuanVideo15Blocks(SequentialPipelineBlocks): + model_name = "hunyuan-video-1.5" + block_classes = [ + HunyuanVideo15TextEncoderStep, + HunyuanVideo15CoreDenoiseStep, + HunyuanVideo15VaeDecoderStep, + ] + block_names = ["text_encoder", "denoise", "decode"] + + @property + def description(self): + return "Modular pipeline blocks for HunyuanVideo 1.5 text-to-video." 
+ + @property + def outputs(self): + return [OutputParam.template("videos")] + + +# auto_docstring +class HunyuanVideo15Image2VideoCoreDenoiseStep(SequentialPipelineBlocks): + model_name = "hunyuan-video-1.5" + block_classes = [ + HunyuanVideo15TextInputStep, + HunyuanVideo15SetTimestepsStep, + HunyuanVideo15Image2VideoPrepareLatentsStep, + HunyuanVideo15Image2VideoDenoiseStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] + + @property + def description(self): + return "I2V denoise block that takes encoded conditions, an image, and runs the denoising process." + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# auto_docstring +class HunyuanVideo15Image2VideoBlocks(SequentialPipelineBlocks): + model_name = "hunyuan-video-1.5" + block_classes = [ + HunyuanVideo15TextEncoderStep, + HunyuanVideo15Image2VideoCoreDenoiseStep, + HunyuanVideo15VaeDecoderStep, + ] + block_names = ["text_encoder", "denoise", "decode"] + + @property + def description(self): + return "Modular pipeline blocks for HunyuanVideo 1.5 image-to-video." + + @property + def outputs(self): + return [OutputParam.template("videos")] diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_pipeline.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_pipeline.py new file mode 100644 index 000000000000..75acd21e248f --- /dev/null +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_pipeline.py @@ -0,0 +1,114 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from ...loaders import HunyuanVideoLoraLoaderMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) + + +class HunyuanVideo15ModularPipeline( + ModularPipeline, + HunyuanVideoLoraLoaderMixin, +): + """ + A ModularPipeline for HunyuanVideo 1.5. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "HunyuanVideo15Blocks" + + @property + def vae_spatial_compression_ratio(self): + if getattr(self, "vae", None) is not None: + return self.vae.spatial_compression_ratio + return 16 + + @property + def vae_temporal_compression_ratio(self): + if getattr(self, "vae", None) is not None: + return self.vae.temporal_compression_ratio + return 4 + + @property + def num_channels_latents(self): + if getattr(self, "vae", None) is not None: + return self.vae.config.latent_channels + return 32 + + @property + def target_size(self): + if getattr(self, "transformer", None) is not None: + return self.transformer.config.target_size + return 640 + + @property + def default_aspect_ratio(self): + return (16, 9) + + @property + def default_height(self): + from ...pipelines.hunyuan_video1_5.image_processor import HunyuanVideo15ImageProcessor + processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + h, w = processor.calculate_default_height_width( + self.default_aspect_ratio[1], self.default_aspect_ratio[0], self.target_size + ) + return h + + @property + def default_width(self): + from ...pipelines.hunyuan_video1_5.image_processor import HunyuanVideo15ImageProcessor + processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + h, w = processor.calculate_default_height_width( + self.default_aspect_ratio[1], self.default_aspect_ratio[0], self.target_size + ) + return w + + 
@property + def tokenizer_max_length(self): + return 1000 + + @property + def tokenizer_2_max_length(self): + return 256 + + @property + def system_message(self): + return ( + "You are a helpful assistant. Describe the video by detailing the following aspects: " + "1. The main content and theme of the video. " + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. " + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. " + "4. background environment, light, style and atmosphere. " + "5. camera angles, movements, and transitions used in the video." + ) + + @property + def prompt_template_encode_start_idx(self): + return 108 + + @property + def vision_num_semantic_tokens(self): + return 729 + + @property + def vision_states_dim(self): + if getattr(self, "transformer", None) is not None: + return self.transformer.config.image_embed_dim + return 1152 + + diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 9cd2f9f5c6ae..bad94e0df8d3 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -132,6 +132,7 @@ def _helios_pyramid_map_fn(config_dict=None): ("z-image", _create_default_map_fn("ZImageModularPipeline")), ("helios", _create_default_map_fn("HeliosModularPipeline")), ("helios-pyramid", _helios_pyramid_map_fn), + ("hunyuan-video-1.5", _create_default_map_fn("HunyuanVideo15ModularPipeline")), ] ) diff --git a/tests/modular_pipelines/hunyuan_video1_5/__init__.py b/tests/modular_pipelines/hunyuan_video1_5/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/modular_pipelines/hunyuan_video1_5/test_modular_pipeline_hunyuan_video1_5.py b/tests/modular_pipelines/hunyuan_video1_5/test_modular_pipeline_hunyuan_video1_5.py new file mode 100644 index 000000000000..e0107db04d33 --- /dev/null +++ 
b/tests/modular_pipelines/hunyuan_video1_5/test_modular_pipeline_hunyuan_video1_5.py @@ -0,0 +1,99 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from diffusers.modular_pipelines import ( + HunyuanVideo15Blocks, + HunyuanVideo15Image2VideoBlocks, + HunyuanVideo15ModularPipeline, +) + + +class TestHunyuanVideo15ModularPipelineStructure(unittest.TestCase): + def test_import(self): + blocks = HunyuanVideo15Blocks() + self.assertIsNotNone(blocks) + + def test_pipeline_class(self): + blocks = HunyuanVideo15Blocks() + pipe = blocks.init_pipeline() + self.assertIsInstance(pipe, HunyuanVideo15ModularPipeline) + + def test_block_names(self): + blocks = HunyuanVideo15Blocks() + self.assertEqual(blocks.block_names, ["text_encoder", "denoise", "decode"]) + + def test_denoise_sub_blocks(self): + blocks = HunyuanVideo15Blocks() + denoise = blocks.sub_blocks["denoise"] + self.assertEqual( + list(denoise.sub_blocks.keys()), + ["input", "set_timesteps", "prepare_latents", "denoise"], + ) + + def test_denoise_loop_sub_blocks(self): + blocks = HunyuanVideo15Blocks() + denoise_loop = blocks.sub_blocks["denoise"].sub_blocks["denoise"] + self.assertEqual( + list(denoise_loop.sub_blocks.keys()), + ["before_denoiser", "denoiser", "after_denoiser"], + ) + + def test_expected_components(self): + blocks = HunyuanVideo15Blocks() + comp_names = {c.name for c in blocks.expected_components} + self.assertIn("transformer", 
comp_names) + self.assertIn("vae", comp_names) + self.assertIn("text_encoder", comp_names) + self.assertIn("text_encoder_2", comp_names) + self.assertIn("tokenizer", comp_names) + self.assertIn("tokenizer_2", comp_names) + self.assertIn("scheduler", comp_names) + self.assertIn("guider", comp_names) + + def test_model_name(self): + blocks = HunyuanVideo15Blocks() + self.assertEqual(blocks.model_name, "hunyuan-video-1.5") + + def test_i2v_import(self): + blocks = HunyuanVideo15Image2VideoBlocks() + self.assertIsNotNone(blocks) + + def test_i2v_pipeline_class(self): + blocks = HunyuanVideo15Image2VideoBlocks() + pipe = blocks.init_pipeline() + self.assertIsInstance(pipe, HunyuanVideo15ModularPipeline) + + def test_i2v_has_image_input(self): + blocks = HunyuanVideo15Image2VideoBlocks() + input_names = {inp.name for inp in blocks.inputs} + self.assertIn("image", input_names) + + def test_i2v_has_image_encoder(self): + blocks = HunyuanVideo15Image2VideoBlocks() + comp_names = {c.name for c in blocks.expected_components} + self.assertIn("image_encoder", comp_names) + self.assertIn("feature_extractor", comp_names) + + def test_top_level_export(self): + from diffusers import HunyuanVideo15Blocks as Top, HunyuanVideo15ModularPipeline as TopPipe + + self.assertIs(Top, HunyuanVideo15Blocks) + self.assertIs(TopPipe, HunyuanVideo15ModularPipeline) + + +if __name__ == "__main__": + unittest.main() From 609063890a45d6e1b53431cdd6330f40f0e33af8 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Wed, 1 Apr 2026 23:39:35 -0700 Subject: [PATCH 02/14] Fix I2V latent/cond spatial dimension mismatch --- .../hunyuan_video1_5/before_denoise.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py index 036f320931e1..f6e13de635ad 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py +++ 
b/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py @@ -270,10 +270,17 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta dtype = block_state.dtype batch_size = block_state.batch_size * block_state.num_videos_per_prompt - height = block_state.height or components.default_height - width = block_state.width or components.default_width num_frames = block_state.num_frames + # Resize/crop image to target resolution first (determines latent dims) + image = block_state.image + from ...pipelines.hunyuan_video1_5.image_processor import HunyuanVideo15ImageProcessor + video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=components.vae_spatial_compression_ratio) + height, width = video_processor.calculate_default_height_width( + height=image.size[1], width=image.size[0], target_size=components.target_size + ) + image = video_processor.resize(image, height=height, width=width, resize_mode="crop") + num_channels_latents = components.num_channels_latents latent_height = height // components.vae_spatial_compression_ratio latent_width = width // components.vae_spatial_compression_ratio @@ -285,15 +292,6 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta shape = (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width) block_state.latents = randn_tensor(shape, generator=block_state.generator, device=device, dtype=dtype) - # Resize/crop image to target resolution (matching upstream flow) - image = block_state.image - from ...pipelines.hunyuan_video1_5.image_processor import HunyuanVideo15ImageProcessor - video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=components.vae_spatial_compression_ratio) - height, width = video_processor.calculate_default_height_width( - height=image.size[1], width=image.size[0], target_size=components.target_size - ) - image = video_processor.resize(image, height=height, width=width, resize_mode="crop") - # Encode image for 
Siglip embeddings image_encoder_dtype = next(components.image_encoder.parameters()).dtype image_inputs = components.feature_extractor.preprocess( From 85802a7ec19b8517ad5c6ee41bf96317308b4485 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Thu, 2 Apr 2026 02:19:09 -0700 Subject: [PATCH 03/14] Fix guidance_scale default to 7.5 matching ClassifierFreeGuidance --- src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py | 4 ++-- src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py index bcb3f55b2047..d735d1a07663 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py @@ -80,7 +80,7 @@ def expected_components(self) -> list[ComponentSpec]: ComponentSpec( "guider", ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 1.0}), + config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config", ), ComponentSpec("transformer", HunyuanVideo15Transformer3DModel), @@ -252,7 +252,7 @@ def expected_components(self) -> list[ComponentSpec]: ComponentSpec( "guider", ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 1.0}), + config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config", ), ComponentSpec("transformer", HunyuanVideo15Transformer3DModel), diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py index 1d81d445c89b..f21fccbe8abf 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py @@ -48,7 +48,7 @@ def expected_components(self) -> list[ComponentSpec]: ComponentSpec( "guider", ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 1.0}), + 
config=FrozenDict({"guidance_scale": 7.5}), default_creation_method="from_config", ), ] From a3d814bbb1265423a386ebf84e0040c8d29965b4 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Thu, 2 Apr 2026 04:45:07 -0700 Subject: [PATCH 04/14] Fix tokenizer type: use Qwen2TokenizerFast to match model --- src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py index f21fccbe8abf..1bf50a0888cb 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py @@ -13,7 +13,7 @@ # limitations under the License. import torch -from transformers import ByT5Tokenizer, Qwen2_5_VLTextModel, Qwen2Tokenizer, T5EncoderModel +from transformers import ByT5Tokenizer, Qwen2_5_VLTextModel, Qwen2TokenizerFast, T5EncoderModel from ...configuration_utils import FrozenDict from ...guiders import ClassifierFreeGuidance @@ -42,7 +42,7 @@ def description(self) -> str: def expected_components(self) -> list[ComponentSpec]: return [ ComponentSpec("text_encoder", Qwen2_5_VLTextModel), - ComponentSpec("tokenizer", Qwen2Tokenizer), + ComponentSpec("tokenizer", Qwen2TokenizerFast), ComponentSpec("text_encoder_2", T5EncoderModel), ComponentSpec("tokenizer_2", ByT5Tokenizer), ComponentSpec( From 22e793915c13a9322b898b7836bb6cee0e559c05 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Thu, 2 Apr 2026 05:53:47 -0700 Subject: [PATCH 05/14] Fix system message string formatting to match standard pipeline --- .../hunyuan_video1_5/encoders.py | 14 ++++++++------ .../hunyuan_video1_5/modular_pipeline.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py index 
1bf50a0888cb..d38d1cc89c20 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py @@ -90,12 +90,14 @@ def _get_mllm_prompt_embeds( device, tokenizer_max_length=1000, num_hidden_layers_to_skip=2, - system_message="You are a helpful assistant. Describe the video by detailing the following aspects: " - "1. The main content and theme of the video. " - "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. " - "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. " - "4. background environment, light, style and atmosphere. " - "5. camera angles, movements, and transitions used in the video.", + # fmt: off + system_message="You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video.", + # fmt: on crop_start=108, ): prompt = [prompt] if isinstance(prompt, str) else prompt diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_pipeline.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_pipeline.py index 75acd21e248f..362904b58557 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_pipeline.py @@ -88,14 +88,14 @@ def tokenizer_2_max_length(self): @property def system_message(self): - return ( - "You are a helpful assistant. Describe the video by detailing the following aspects: " - "1. The main content and theme of the video. " - "2. 
The color, shape, size, texture, quantity, text, and spatial relationships of the objects. " - "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. " - "4. background environment, light, style and atmosphere. " - "5. camera angles, movements, and transitions used in the video." - ) + # fmt: off + return "You are a helpful assistant. Describe the video by detailing the following aspects: \ + 1. The main content and theme of the video. \ + 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ + 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ + 4. background environment, light, style and atmosphere. \ + 5. camera angles, movements, and transitions used in the video." + # fmt: on @property def prompt_template_encode_start_idx(self): From 00564fe269787a9cf4661296697cb7a7aa078659 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Thu, 2 Apr 2026 06:09:52 -0700 Subject: [PATCH 06/14] Rewrite HunyuanVideo 1.5 modular: use standard pipeline methods directly --- .../hunyuan_video1_5/__init__.py | 4 +- .../hunyuan_video1_5/before_denoise.py | 171 ++++------------ .../hunyuan_video1_5/decoders.py | 5 +- .../hunyuan_video1_5/denoise.py | 148 +++----------- .../hunyuan_video1_5/encoders.py | 182 +++--------------- .../modular_blocks_hunyuan_video1_5.py | 42 +--- .../hunyuan_video1_5/modular_pipeline.py | 58 ++---- .../test_modular_pipeline_hunyuan_video1_5.py | 21 -- 8 files changed, 104 insertions(+), 527 deletions(-) diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/__init__.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/__init__.py index 73de8277d004..3ad7d17a8357 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/__init__.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/__init__.py @@ -21,7 +21,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - 
_import_structure["modular_blocks_hunyuan_video1_5"] = ["HunyuanVideo15Blocks", "HunyuanVideo15Image2VideoBlocks"] + _import_structure["modular_blocks_hunyuan_video1_5"] = ["HunyuanVideo15Blocks"] _import_structure["modular_pipeline"] = ["HunyuanVideo15ModularPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -31,7 +31,7 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .modular_blocks_hunyuan_video1_5 import HunyuanVideo15Blocks, HunyuanVideo15Image2VideoBlocks + from .modular_blocks_hunyuan_video1_5 import HunyuanVideo15Blocks from .modular_pipeline import HunyuanVideo15ModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py index f6e13de635ad..8c1b65d12c63 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py @@ -18,6 +18,7 @@ import torch from ...models import HunyuanVideo15Transformer3DModel +from ...pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5 import HunyuanVideo15Pipeline from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging from ...utils.torch_utils import randn_tensor @@ -29,6 +30,7 @@ logger = logging.get_logger(__name__) +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps=None, @@ -89,10 +91,8 @@ def intermediate_outputs(self) -> list[OutputParam]: @torch.no_grad() def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - block_state.batch_size = getattr(block_state, "batch_size", None) or block_state.prompt_embeds.shape[0] block_state.dtype = components.transformer.dtype - self.set_block_state(state, block_state) return 
components, state @@ -102,9 +102,7 @@ class HunyuanVideo15SetTimestepsStep(ModularPipelineBlocks): @property def expected_components(self) -> list[ComponentSpec]: - return [ - ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), - ] + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] @property def description(self) -> str: @@ -124,6 +122,7 @@ def intermediate_outputs(self) -> list[OutputParam]: OutputParam("num_inference_steps", type_hint=int), ] + # Copied from pipeline_hunyuan_video1_5.py line 702-704 @torch.no_grad() def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) @@ -134,10 +133,7 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta sigmas = np.linspace(1.0, 0.0, block_state.num_inference_steps + 1)[:-1] block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( - components.scheduler, - block_state.num_inference_steps, - device, - sigmas=sigmas, + components.scheduler, block_state.num_inference_steps, device, sigmas=sigmas ) self.set_block_state(state, block_state) @@ -149,7 +145,7 @@ class HunyuanVideo15PrepareLatentsStep(ModularPipelineBlocks): @property def description(self) -> str: - return "Prepare latents step for text-to-video generation" + return "Prepare latents, conditioning latents, mask, and image_embeds for T2V" @property def inputs(self) -> list[InputParam]: @@ -173,34 +169,46 @@ def intermediate_outputs(self) -> list[OutputParam]: OutputParam("image_embeds", type_hint=torch.Tensor), ] + # Copied from pipeline_hunyuan_video1_5.py lines 652-655, 706-725 @torch.no_grad() def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) device = components._execution_device dtype = block_state.dtype + # Calculate default height/width if not provided (line 652-655) + height = block_state.height + width = 
block_state.width + if height is None and width is None: + height, width = components.video_processor.calculate_default_height_width( + components.default_aspect_ratio[1], components.default_aspect_ratio[0], components.target_size + ) + batch_size = block_state.batch_size * block_state.num_videos_per_prompt - height = block_state.height or components.default_height - width = block_state.width or components.default_width num_frames = block_state.num_frames - num_channels_latents = components.num_channels_latents - latent_height = height // components.vae_spatial_compression_ratio - latent_width = width // components.vae_spatial_compression_ratio - latent_num_frames = (num_frames - 1) // components.vae_temporal_compression_ratio + 1 - - if block_state.latents is not None: - block_state.latents = block_state.latents.to(device=device, dtype=dtype) - else: - shape = (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width) - block_state.latents = randn_tensor(shape, generator=block_state.generator, device=device, dtype=dtype) + # Copied from HunyuanVideo15Pipeline.prepare_latents (lines 477-505, 707-717) + block_state.latents = HunyuanVideo15Pipeline.prepare_latents( + components, + batch_size, + components.num_channels_latents, + height, + width, + num_frames, + dtype, + device, + block_state.generator, + block_state.latents, + ) - # T2V: zero cond_latents and mask - b, c, f, h, w = block_state.latents.shape - block_state.cond_latents_concat = torch.zeros(b, c, f, h, w, dtype=dtype, device=device) - block_state.mask_concat = torch.zeros(b, 1, f, h, w, dtype=dtype, device=device) + # Copied from HunyuanVideo15Pipeline.prepare_cond_latents_and_mask (lines 508-524, 718) + cond_latents_concat, mask_concat = HunyuanVideo15Pipeline.prepare_cond_latents_and_mask( + components, block_state.latents, dtype, device + ) + block_state.cond_latents_concat = cond_latents_concat + block_state.mask_concat = mask_concat - # T2V: zero image_embeds + # T2V: zero 
image_embeds (line 719-725) block_state.image_embeds = torch.zeros( block_state.batch_size, components.vision_num_semantic_tokens, @@ -211,112 +219,3 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta self.set_block_state(state, block_state) return components, state - - -def retrieve_latents(encoder_output, generator=None, sample_mode="sample"): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - raise AttributeError("Could not access latents of provided encoder_output") - - -class HunyuanVideo15Image2VideoPrepareLatentsStep(ModularPipelineBlocks): - model_name = "hunyuan-video-1.5" - - @property - def description(self) -> str: - return "Prepare latents step for image-to-video: encodes the first frame and creates conditioning mask" - - @property - def expected_components(self) -> list[ComponentSpec]: - from ...models import AutoencoderKLHunyuanVideo15 - from transformers import SiglipVisionModel, SiglipImageProcessor - return [ - ComponentSpec("vae", AutoencoderKLHunyuanVideo15), - ComponentSpec("image_encoder", SiglipVisionModel), - ComponentSpec("feature_extractor", SiglipImageProcessor), - ] - - @property - def inputs(self) -> list[InputParam]: - return [ - InputParam("image", required=True), - InputParam("height", type_hint=int), - InputParam("width", type_hint=int), - InputParam("num_frames", type_hint=int, default=121), - InputParam("latents", type_hint=torch.Tensor | None), - InputParam("num_videos_per_prompt", type_hint=int, default=1), - InputParam("generator"), - InputParam("batch_size", required=True, type_hint=int), - InputParam("dtype", type_hint=torch.dtype), - ] - - @property - def intermediate_outputs(self) -> list[OutputParam]: - return [ - 
OutputParam("latents", type_hint=torch.Tensor), - OutputParam("cond_latents_concat", type_hint=torch.Tensor), - OutputParam("mask_concat", type_hint=torch.Tensor), - OutputParam("image_embeds", type_hint=torch.Tensor), - ] - - @torch.no_grad() - def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - device = components._execution_device - dtype = block_state.dtype - - batch_size = block_state.batch_size * block_state.num_videos_per_prompt - num_frames = block_state.num_frames - - # Resize/crop image to target resolution first (determines latent dims) - image = block_state.image - from ...pipelines.hunyuan_video1_5.image_processor import HunyuanVideo15ImageProcessor - video_processor = HunyuanVideo15ImageProcessor(vae_scale_factor=components.vae_spatial_compression_ratio) - height, width = video_processor.calculate_default_height_width( - height=image.size[1], width=image.size[0], target_size=components.target_size - ) - image = video_processor.resize(image, height=height, width=width, resize_mode="crop") - - num_channels_latents = components.num_channels_latents - latent_height = height // components.vae_spatial_compression_ratio - latent_width = width // components.vae_spatial_compression_ratio - latent_num_frames = (num_frames - 1) // components.vae_temporal_compression_ratio + 1 - - if block_state.latents is not None: - block_state.latents = block_state.latents.to(device=device, dtype=dtype) - else: - shape = (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width) - block_state.latents = randn_tensor(shape, generator=block_state.generator, device=device, dtype=dtype) - - # Encode image for Siglip embeddings - image_encoder_dtype = next(components.image_encoder.parameters()).dtype - image_inputs = components.feature_extractor.preprocess( - images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True - ) - image_inputs = 
image_inputs.to(device=device, dtype=image_encoder_dtype) - image_embeds = components.image_encoder(**image_inputs).last_hidden_state - image_embeds = image_embeds.repeat(batch_size, 1, 1) - block_state.image_embeds = image_embeds.to(device=device, dtype=dtype) - - # Encode image for VAE conditioning latents - vae_dtype = components.vae.dtype - image_tensor = video_processor.preprocess(image, height=height, width=width).to(device, dtype=vae_dtype) - image_tensor = image_tensor.unsqueeze(2) - image_latents = retrieve_latents(components.vae.encode(image_tensor), sample_mode="argmax") - image_latents = image_latents * components.vae.config.scaling_factor - - b, c, f, h, w = block_state.latents.shape - latent_condition = image_latents.repeat(batch_size, 1, f, 1, 1) - latent_condition[:, :, 1:, :, :] = 0 - block_state.cond_latents_concat = latent_condition.to(device=device, dtype=dtype) - - latent_mask = torch.zeros(b, 1, f, h, w, dtype=dtype, device=device) - latent_mask[:, :, 0, :, :] = 1.0 - block_state.mask_concat = latent_mask - - self.set_block_state(state, block_state) - return components, state diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/decoders.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/decoders.py index 06f2aab6daa7..f5eddb16b5ed 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/decoders.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/decoders.py @@ -20,8 +20,8 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKLHunyuanVideo15 +from ...pipelines.hunyuan_video1_5.image_processor import HunyuanVideo15ImageProcessor from ...utils import logging -from ...video_processor import VideoProcessor from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam @@ -38,7 +38,7 @@ def expected_components(self) -> list[ComponentSpec]: ComponentSpec("vae", AutoencoderKLHunyuanVideo15), ComponentSpec( 
"video_processor", - VideoProcessor, + HunyuanVideo15ImageProcessor, config=FrozenDict({"vae_scale_factor": 16}), default_creation_method="from_config", ), @@ -65,6 +65,7 @@ def intermediate_outputs(self) -> list[OutputParam]: ) ] + # Copied from pipeline_hunyuan_video1_5.py lines 823-829 @torch.no_grad() def __call__(self, components, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py index d735d1a07663..e66e4701b4cb 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py @@ -47,9 +47,9 @@ def inputs(self) -> list[InputParam]: InputParam("latents", required=True, type_hint=torch.Tensor), InputParam("cond_latents_concat", required=True, type_hint=torch.Tensor), InputParam("mask_concat", required=True, type_hint=torch.Tensor), - InputParam("dtype", required=True, type_hint=torch.dtype), ] + # Copied from pipeline_hunyuan_video1_5.py line 737 @torch.no_grad() def __call__(self, components: HunyuanVideo15ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): block_state.latent_model_input = torch.cat( @@ -106,23 +106,34 @@ def inputs(self) -> list[InputParam]: inputs.append(InputParam(name=value, required=True, type_hint=torch.Tensor)) return inputs + # Copied from pipeline_hunyuan_video1_5.py lines 739-803 @torch.no_grad() def __call__( self, components: HunyuanVideo15ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor ) -> PipelineState: + timestep = t.expand(block_state.latent_model_input.shape[0]).to(block_state.latent_model_input.dtype) + + # Step 1: Collect model inputs + guider_inputs = { + input_name: tuple(getattr(block_state, v) for v in value) if isinstance(value, tuple) else getattr(block_state, value) + for input_name, value in self._guider_input_fields.items() + } + + # Step 2: 
Update guider state components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) - timestep = t.expand(block_state.latent_model_input.shape[0]).to(block_state.latent_model_input.dtype) + # Step 3: Prepare batched inputs + guider_state = components.guider.prepare_inputs(guider_inputs) + # Step 4: Run denoiser for each batch for guider_state_batch in guider_state: components.guider.prepare_models(components.transformer) - cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = { - k: v for k, v in cond_kwargs.items() if k in self._guider_input_fields.keys() + input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() } - context_name = getattr(guider_state_batch, components.guider._identifier_key, None) + context_name = getattr(guider_state_batch, components.guider._identifier_key) with components.transformer.cache_context(context_name): guider_state_batch.noise_pred = components.transformer( hidden_states=block_state.latent_model_input, @@ -132,8 +143,10 @@ def __call__( return_dict=False, **cond_kwargs, )[0] + components.guider.cleanup_models(components.transformer) + # Step 5: Combine predictions block_state.noise_pred = components.guider(guider_state)[0] return components, block_state @@ -144,22 +157,18 @@ class HunyuanVideo15LoopAfterDenoiser(ModularPipelineBlocks): @property def expected_components(self) -> list[ComponentSpec]: - return [ - ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), - ] + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] @property def description(self) -> str: return "Step within the denoising loop that updates the latents" + # Copied from pipeline_hunyuan_video1_5.py lines 805-812 @torch.no_grad() def __call__(self, components: HunyuanVideo15ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): latents_dtype = 
block_state.latents.dtype block_state.latents = components.scheduler.step( - block_state.noise_pred, - t, - block_state.latents, - return_dict=False, + block_state.noise_pred, t, block_state.latents, return_dict=False )[0] if block_state.latents.dtype != latents_dtype: @@ -222,120 +231,9 @@ class HunyuanVideo15DenoiseStep(HunyuanVideo15DenoiseLoopWrapper): def description(self) -> str: return ( "Denoise step that iteratively denoises the latents.\n" - "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + "At each iteration:\n" " - `HunyuanVideo15LoopBeforeDenoiser`\n" " - `HunyuanVideo15LoopDenoiser`\n" " - `HunyuanVideo15LoopAfterDenoiser`\n" "This block supports text-to-video tasks." ) - - -class HunyuanVideo15Image2VideoLoopDenoiser(ModularPipelineBlocks): - model_name = "hunyuan-video-1.5" - - def __init__(self, guider_input_fields=None): - if guider_input_fields is None: - guider_input_fields = { - "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), - "encoder_attention_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"), - "encoder_hidden_states_2": ("prompt_embeds_2", "negative_prompt_embeds_2"), - "encoder_attention_mask_2": ("prompt_embeds_mask_2", "negative_prompt_embeds_mask_2"), - } - if not isinstance(guider_input_fields, dict): - raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") - self._guider_input_fields = guider_input_fields - super().__init__() - - @property - def expected_components(self) -> list[ComponentSpec]: - return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 7.5}), - default_creation_method="from_config", - ), - ComponentSpec("transformer", HunyuanVideo15Transformer3DModel), - ] - - @property - def description(self) -> str: - return "I2V denoiser with MeanFlow timestep_r support" - - @property - def inputs(self) -> list[InputParam]: - inputs = [ - InputParam("attention_kwargs"), - 
InputParam("num_inference_steps", required=True, type_hint=int), - InputParam("image_embeds", type_hint=torch.Tensor), - InputParam("timesteps", required=True, type_hint=torch.Tensor), - ] - for value in self._guider_input_fields.values(): - if isinstance(value, tuple): - inputs.append(InputParam(name=value[0], required=True, type_hint=torch.Tensor)) - for neg_name in value[1:]: - inputs.append(InputParam(name=neg_name, type_hint=torch.Tensor)) - else: - inputs.append(InputParam(name=value, required=True, type_hint=torch.Tensor)) - return inputs - - @torch.no_grad() - def __call__( - self, components: HunyuanVideo15ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor - ) -> PipelineState: - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) - guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) - - timestep = t.expand(block_state.latent_model_input.shape[0]).to(block_state.latent_model_input.dtype) - - # MeanFlow: compute timestep_r - timestep_r = None - if components.transformer.config.use_meanflow: - if i == len(block_state.timesteps) - 1: - timestep_r = torch.tensor([0.0], device=timestep.device) - else: - timestep_r = block_state.timesteps[i + 1] - timestep_r = timestep_r.expand(block_state.latents.shape[0]).to(block_state.latents.dtype) - - for guider_state_batch in guider_state: - components.guider.prepare_models(components.transformer) - cond_kwargs = guider_state_batch.as_dict() - cond_kwargs = { - k: v for k, v in cond_kwargs.items() if k in self._guider_input_fields.keys() - } - - context_name = getattr(guider_state_batch, components.guider._identifier_key, None) - with components.transformer.cache_context(context_name): - guider_state_batch.noise_pred = components.transformer( - hidden_states=block_state.latent_model_input, - image_embeds=block_state.image_embeds, - timestep=timestep, - timestep_r=timestep_r, - 
attention_kwargs=block_state.attention_kwargs, - return_dict=False, - **cond_kwargs, - )[0] - components.guider.cleanup_models(components.transformer) - - block_state.noise_pred = components.guider(guider_state)[0] - - return components, block_state - - -class HunyuanVideo15Image2VideoDenoiseStep(HunyuanVideo15DenoiseLoopWrapper): - block_classes = [ - HunyuanVideo15LoopBeforeDenoiser, - HunyuanVideo15Image2VideoLoopDenoiser(), - HunyuanVideo15LoopAfterDenoiser, - ] - block_names = ["before_denoiser", "denoiser", "after_denoiser"] - - @property - def description(self) -> str: - return ( - "Denoise step for image-to-video with MeanFlow support.\n" - "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" - " - `HunyuanVideo15LoopBeforeDenoiser`\n" - " - `HunyuanVideo15Image2VideoLoopDenoiser`\n" - " - `HunyuanVideo15LoopAfterDenoiser`" - ) diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py index d38d1cc89c20..34676c454665 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py @@ -17,15 +17,15 @@ from ...configuration_utils import FrozenDict from ...guiders import ClassifierFreeGuidance -from ...utils import logging -from ..modular_pipeline import ModularPipelineBlocks, PipelineState -from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam -from .modular_pipeline import HunyuanVideo15ModularPipeline - from ...pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5 import ( + HunyuanVideo15Pipeline, format_text_input, extract_glyph_texts, ) +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import HunyuanVideo15ModularPipeline logger = logging.get_logger(__name__) @@ -82,152 +82,7 @@ def intermediate_outputs(self) -> 
list[OutputParam]: OutputParam("negative_prompt_embeds_mask_2", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"), ] - @staticmethod - def _get_mllm_prompt_embeds( - text_encoder, - tokenizer, - prompt, - device, - tokenizer_max_length=1000, - num_hidden_layers_to_skip=2, - # fmt: off - system_message="You are a helpful assistant. Describe the video by detailing the following aspects: \ - 1. The main content and theme of the video. \ - 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ - 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ - 4. background environment, light, style and atmosphere. \ - 5. camera angles, movements, and transitions used in the video.", - # fmt: on - crop_start=108, - ): - prompt = [prompt] if isinstance(prompt, str) else prompt - prompt = format_text_input(prompt, system_message) - - text_inputs = tokenizer.apply_chat_template( - prompt, - add_generation_prompt=True, - tokenize=True, - return_dict=True, - padding="max_length", - max_length=tokenizer_max_length + crop_start, - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids.to(device=device) - prompt_attention_mask = text_inputs.attention_mask.to(device=device) - - prompt_embeds = text_encoder( - input_ids=text_input_ids, - attention_mask=prompt_attention_mask, - output_hidden_states=True, - ).hidden_states[-(num_hidden_layers_to_skip + 1)] - - if crop_start is not None and crop_start > 0: - prompt_embeds = prompt_embeds[:, crop_start:] - prompt_attention_mask = prompt_attention_mask[:, crop_start:] - - return prompt_embeds, prompt_attention_mask - - @staticmethod - def _get_byt5_prompt_embeds(tokenizer, text_encoder, prompt, device, tokenizer_max_length=256): - prompt = [prompt] if isinstance(prompt, str) else prompt - glyph_texts = [extract_glyph_texts(p) for p in prompt] - - prompt_embeds_list = [] - prompt_embeds_mask_list = [] - - for 
glyph_text in glyph_texts: - if glyph_text is None: - glyph_text_embeds = torch.zeros( - (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype - ) - glyph_text_embeds_mask = torch.zeros((1, tokenizer_max_length), device=device, dtype=torch.int64) - else: - txt_tokens = tokenizer( - glyph_text, - padding="max_length", - max_length=tokenizer_max_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ).to(device) - - glyph_text_embeds = text_encoder( - input_ids=txt_tokens.input_ids, - attention_mask=txt_tokens.attention_mask.float(), - )[0].to(device=device) - glyph_text_embeds_mask = txt_tokens.attention_mask.to(device=device) - - prompt_embeds_list.append(glyph_text_embeds) - prompt_embeds_mask_list.append(glyph_text_embeds_mask) - - return torch.cat(prompt_embeds_list, dim=0), torch.cat(prompt_embeds_mask_list, dim=0) - - @staticmethod - def encode_prompt( - components, - prompt, - device=None, - dtype=None, - batch_size=1, - num_videos_per_prompt=1, - prompt_embeds=None, - prompt_embeds_mask=None, - prompt_embeds_2=None, - prompt_embeds_mask_2=None, - ): - device = device or components._execution_device - dtype = dtype or components.text_encoder.dtype - - if prompt is None: - prompt = [""] * batch_size - - prompt = [prompt] if isinstance(prompt, str) else prompt - - if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = HunyuanVideo15TextEncoderStep._get_mllm_prompt_embeds( - tokenizer=components.tokenizer, - text_encoder=components.text_encoder, - prompt=prompt, - device=device, - tokenizer_max_length=components.tokenizer_max_length, - system_message=components.system_message, - crop_start=components.prompt_template_encode_start_idx, - ) - - if prompt_embeds_2 is None: - prompt_embeds_2, prompt_embeds_mask_2 = HunyuanVideo15TextEncoderStep._get_byt5_prompt_embeds( - tokenizer=components.tokenizer_2, - text_encoder=components.text_encoder_2, - prompt=prompt, - device=device, - 
tokenizer_max_length=components.tokenizer_2_max_length, - ) - - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1).view( - batch_size * num_videos_per_prompt, seq_len, -1 - ) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1).view( - batch_size * num_videos_per_prompt, seq_len - ) - - _, seq_len_2, _ = prompt_embeds_2.shape - prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1).view( - batch_size * num_videos_per_prompt, seq_len_2, -1 - ) - prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1).view( - batch_size * num_videos_per_prompt, seq_len_2 - ) - - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device) - prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device) - prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device) - - return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 - + # Copied from HunyuanVideo15Pipeline.encode_prompt @torch.no_grad() def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) @@ -235,51 +90,58 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta dtype = components.transformer.dtype prompt = block_state.prompt + negative_prompt = block_state.negative_prompt + num_videos_per_prompt = block_state.num_videos_per_prompt + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) + elif getattr(block_state, "prompt_embeds", None) is not None: + batch_size = block_state.prompt_embeds.shape[0] else: - batch_size = getattr(block_state, "prompt_embeds", torch.empty(1)).shape[0] + batch_size = 1 + # Encode positive prompt (reuse pipeline's encode_prompt verbatim) ( block_state.prompt_embeds, 
block_state.prompt_embeds_mask, block_state.prompt_embeds_2, block_state.prompt_embeds_mask_2, - ) = self.encode_prompt( - components=components, + ) = HunyuanVideo15Pipeline.encode_prompt( + components, prompt=prompt, device=device, dtype=dtype, batch_size=batch_size, - num_videos_per_prompt=getattr(block_state, "num_videos_per_prompt", 1), + num_videos_per_prompt=num_videos_per_prompt, prompt_embeds=getattr(block_state, "prompt_embeds", None), prompt_embeds_mask=getattr(block_state, "prompt_embeds_mask", None), prompt_embeds_2=getattr(block_state, "prompt_embeds_2", None), prompt_embeds_mask_2=getattr(block_state, "prompt_embeds_mask_2", None), ) - if components.guider._enabled and components.guider.num_conditions > 1: - negative_prompt = block_state.negative_prompt + # Encode negative prompt if guider needs it + if components.requires_unconditional_embeds: ( block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask, block_state.negative_prompt_embeds_2, block_state.negative_prompt_embeds_mask_2, - ) = self.encode_prompt( - components=components, + ) = HunyuanVideo15Pipeline.encode_prompt( + components, prompt=negative_prompt, device=device, dtype=dtype, batch_size=batch_size, - num_videos_per_prompt=getattr(block_state, "num_videos_per_prompt", 1), + num_videos_per_prompt=num_videos_per_prompt, prompt_embeds=getattr(block_state, "negative_prompt_embeds", None), prompt_embeds_mask=getattr(block_state, "negative_prompt_embeds_mask", None), prompt_embeds_2=getattr(block_state, "negative_prompt_embeds_2", None), prompt_embeds_mask_2=getattr(block_state, "negative_prompt_embeds_mask_2", None), ) + # Pass batch_size downstream state.set("batch_size", batch_size) self.set_block_state(state, block_state) diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py index f56cd5fb6119..1ae6970deeb1 100644 --- 
a/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py @@ -16,13 +16,12 @@ from ..modular_pipeline import SequentialPipelineBlocks from ..modular_pipeline_utils import OutputParam from .before_denoise import ( - HunyuanVideo15Image2VideoPrepareLatentsStep, HunyuanVideo15PrepareLatentsStep, HunyuanVideo15SetTimestepsStep, HunyuanVideo15TextInputStep, ) from .decoders import HunyuanVideo15VaeDecoderStep -from .denoise import HunyuanVideo15DenoiseStep, HunyuanVideo15Image2VideoDenoiseStep +from .denoise import HunyuanVideo15DenoiseStep from .encoders import HunyuanVideo15TextEncoderStep @@ -66,42 +65,3 @@ def description(self): @property def outputs(self): return [OutputParam.template("videos")] - - -# auto_docstring -class HunyuanVideo15Image2VideoCoreDenoiseStep(SequentialPipelineBlocks): - model_name = "hunyuan-video-1.5" - block_classes = [ - HunyuanVideo15TextInputStep, - HunyuanVideo15SetTimestepsStep, - HunyuanVideo15Image2VideoPrepareLatentsStep, - HunyuanVideo15Image2VideoDenoiseStep, - ] - block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] - - @property - def description(self): - return "I2V denoise block that takes encoded conditions, an image, and runs the denoising process." - - @property - def outputs(self): - return [OutputParam.template("latents")] - - -# auto_docstring -class HunyuanVideo15Image2VideoBlocks(SequentialPipelineBlocks): - model_name = "hunyuan-video-1.5" - block_classes = [ - HunyuanVideo15TextEncoderStep, - HunyuanVideo15Image2VideoCoreDenoiseStep, - HunyuanVideo15VaeDecoderStep, - ] - block_names = ["text_encoder", "denoise", "decode"] - - @property - def description(self): - return "Modular pipeline blocks for HunyuanVideo 1.5 image-to-video." 
- - @property - def outputs(self): - return [OutputParam.template("videos")] diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_pipeline.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_pipeline.py index 362904b58557..090d6c99fc31 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_pipeline.py @@ -32,52 +32,36 @@ class HunyuanVideo15ModularPipeline( default_blocks_name = "HunyuanVideo15Blocks" + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline properties @property - def vae_spatial_compression_ratio(self): - if getattr(self, "vae", None) is not None: - return self.vae.spatial_compression_ratio - return 16 + def vae_scale_factor_spatial(self): + return self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 16 @property - def vae_temporal_compression_ratio(self): - if getattr(self, "vae", None) is not None: - return self.vae.temporal_compression_ratio - return 4 + def vae_scale_factor_temporal(self): + return self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 @property def num_channels_latents(self): - if getattr(self, "vae", None) is not None: - return self.vae.config.latent_channels - return 32 + return self.vae.config.latent_channels if getattr(self, "vae", None) else 32 @property def target_size(self): - if getattr(self, "transformer", None) is not None: - return self.transformer.config.target_size - return 640 + return self.transformer.config.target_size if getattr(self, "transformer", None) else 640 @property def default_aspect_ratio(self): return (16, 9) @property - def default_height(self): - from ...pipelines.hunyuan_video1_5.image_processor import HunyuanVideo15ImageProcessor - processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) - h, w = processor.calculate_default_height_width( - 
self.default_aspect_ratio[1], self.default_aspect_ratio[0], self.target_size - ) - return h + def vision_num_semantic_tokens(self): + return 729 @property - def default_width(self): - from ...pipelines.hunyuan_video1_5.image_processor import HunyuanVideo15ImageProcessor - processor = HunyuanVideo15ImageProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) - h, w = processor.calculate_default_height_width( - self.default_aspect_ratio[1], self.default_aspect_ratio[0], self.target_size - ) - return w + def vision_states_dim(self): + return self.transformer.config.image_embed_dim if getattr(self, "transformer", None) else 1152 + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.__init__ @property def tokenizer_max_length(self): return 1000 @@ -86,29 +70,23 @@ def tokenizer_max_length(self): def tokenizer_2_max_length(self): return 256 + # fmt: off @property def system_message(self): - # fmt: off return "You are a helpful assistant. Describe the video by detailing the following aspects: \ 1. The main content and theme of the video. \ 2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \ 3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \ 4. background environment, light, style and atmosphere. \ 5. camera angles, movements, and transitions used in the video." 
- # fmt: on + # fmt: on @property def prompt_template_encode_start_idx(self): return 108 @property - def vision_num_semantic_tokens(self): - return 729 - - @property - def vision_states_dim(self): - if getattr(self, "transformer", None) is not None: - return self.transformer.config.image_embed_dim - return 1152 - - + def requires_unconditional_embeds(self): + if hasattr(self, "guider") and self.guider is not None: + return self.guider._enabled and self.guider.num_conditions > 1 + return False diff --git a/tests/modular_pipelines/hunyuan_video1_5/test_modular_pipeline_hunyuan_video1_5.py b/tests/modular_pipelines/hunyuan_video1_5/test_modular_pipeline_hunyuan_video1_5.py index e0107db04d33..7bc9c06cbcf7 100644 --- a/tests/modular_pipelines/hunyuan_video1_5/test_modular_pipeline_hunyuan_video1_5.py +++ b/tests/modular_pipelines/hunyuan_video1_5/test_modular_pipeline_hunyuan_video1_5.py @@ -17,7 +17,6 @@ from diffusers.modular_pipelines import ( HunyuanVideo15Blocks, - HunyuanVideo15Image2VideoBlocks, HunyuanVideo15ModularPipeline, ) @@ -68,26 +67,6 @@ def test_model_name(self): blocks = HunyuanVideo15Blocks() self.assertEqual(blocks.model_name, "hunyuan-video-1.5") - def test_i2v_import(self): - blocks = HunyuanVideo15Image2VideoBlocks() - self.assertIsNotNone(blocks) - - def test_i2v_pipeline_class(self): - blocks = HunyuanVideo15Image2VideoBlocks() - pipe = blocks.init_pipeline() - self.assertIsInstance(pipe, HunyuanVideo15ModularPipeline) - - def test_i2v_has_image_input(self): - blocks = HunyuanVideo15Image2VideoBlocks() - input_names = {inp.name for inp in blocks.inputs} - self.assertIn("image", input_names) - - def test_i2v_has_image_encoder(self): - blocks = HunyuanVideo15Image2VideoBlocks() - comp_names = {c.name for c in blocks.expected_components} - self.assertIn("image_encoder", comp_names) - self.assertIn("feature_extractor", comp_names) - def test_top_level_export(self): from diffusers import HunyuanVideo15Blocks as Top, HunyuanVideo15ModularPipeline as 
TopPipe From 7a46b21154b13ec8c942db28bae08f7576d53e46 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Thu, 2 Apr 2026 06:16:08 -0700 Subject: [PATCH 07/14] Remove I2V exports (T2V only for now) --- src/diffusers/__init__.py | 1 - src/diffusers/modular_pipelines/__init__.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6d79e9733381..c4cab87fa533 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -453,7 +453,6 @@ "HeliosPyramidDistilledModularPipeline", "HeliosPyramidModularPipeline", "HunyuanVideo15Blocks", - "HunyuanVideo15Image2VideoBlocks", "HunyuanVideo15ModularPipeline", "QwenImageAutoBlocks", "QwenImageEditAutoBlocks", diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index ae8cb9762f21..01d8c626b3f1 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -90,7 +90,6 @@ ] _import_structure["hunyuan_video1_5"] = [ "HunyuanVideo15Blocks", - "HunyuanVideo15Image2VideoBlocks", "HunyuanVideo15ModularPipeline", ] _import_structure["z_image"] = [ @@ -145,7 +144,7 @@ QwenImageLayeredModularPipeline, QwenImageModularPipeline, ) - from .hunyuan_video1_5 import HunyuanVideo15Blocks, HunyuanVideo15Image2VideoBlocks, HunyuanVideo15ModularPipeline + from .hunyuan_video1_5 import HunyuanVideo15Blocks, HunyuanVideo15ModularPipeline from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline from .wan import ( Wan22Blocks, From 3953a25f69d5ae3c62917de81ec5c675e25d7920 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Thu, 2 Apr 2026 06:30:52 -0700 Subject: [PATCH 08/14] Fix encoder: use static methods directly instead of encode_prompt --- .../hunyuan_video1_5/encoders.py | 125 ++++++++++++------ 1 file changed, 85 insertions(+), 40 deletions(-) diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py 
b/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py index 34676c454665..aa77f0b7dc5d 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py @@ -17,11 +17,7 @@ from ...configuration_utils import FrozenDict from ...guiders import ClassifierFreeGuidance -from ...pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5 import ( - HunyuanVideo15Pipeline, - format_text_input, - extract_glyph_texts, -) +from ...pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5 import HunyuanVideo15Pipeline from ...utils import logging from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam @@ -102,45 +98,94 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta else: batch_size = 1 - # Encode positive prompt (reuse pipeline's encode_prompt verbatim) - ( - block_state.prompt_embeds, - block_state.prompt_embeds_mask, - block_state.prompt_embeds_2, - block_state.prompt_embeds_mask_2, - ) = HunyuanVideo15Pipeline.encode_prompt( - components, - prompt=prompt, - device=device, - dtype=dtype, - batch_size=batch_size, - num_videos_per_prompt=num_videos_per_prompt, - prompt_embeds=getattr(block_state, "prompt_embeds", None), - prompt_embeds_mask=getattr(block_state, "prompt_embeds_mask", None), - prompt_embeds_2=getattr(block_state, "prompt_embeds_2", None), - prompt_embeds_mask_2=getattr(block_state, "prompt_embeds_mask_2", None), - ) + # Encode positive prompt - copied from HunyuanVideo15Pipeline.encode_prompt + prompt_embeds = getattr(block_state, "prompt_embeds", None) + prompt_embeds_mask = getattr(block_state, "prompt_embeds_mask", None) + prompt_embeds_2 = getattr(block_state, "prompt_embeds_2", None) + prompt_embeds_mask_2 = getattr(block_state, "prompt_embeds_mask_2", None) + + if prompt is None: + prompt = [""] * batch_size + prompt = [prompt] if isinstance(prompt, str) else 
prompt + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = HunyuanVideo15Pipeline._get_mllm_prompt_embeds( + tokenizer=components.tokenizer, + text_encoder=components.text_encoder, + prompt=prompt, + device=device, + tokenizer_max_length=components.tokenizer_max_length, + system_message=components.system_message, + crop_start=components.prompt_template_encode_start_idx, + ) - # Encode negative prompt if guider needs it - if components.requires_unconditional_embeds: - ( - block_state.negative_prompt_embeds, - block_state.negative_prompt_embeds_mask, - block_state.negative_prompt_embeds_2, - block_state.negative_prompt_embeds_mask_2, - ) = HunyuanVideo15Pipeline.encode_prompt( - components, - prompt=negative_prompt, + if prompt_embeds_2 is None: + prompt_embeds_2, prompt_embeds_mask_2 = HunyuanVideo15Pipeline._get_byt5_prompt_embeds( + tokenizer=components.tokenizer_2, + text_encoder=components.text_encoder_2, + prompt=prompt, device=device, - dtype=dtype, - batch_size=batch_size, - num_videos_per_prompt=num_videos_per_prompt, - prompt_embeds=getattr(block_state, "negative_prompt_embeds", None), - prompt_embeds_mask=getattr(block_state, "negative_prompt_embeds_mask", None), - prompt_embeds_2=getattr(block_state, "negative_prompt_embeds_2", None), - prompt_embeds_mask_2=getattr(block_state, "negative_prompt_embeds_mask_2", None), + tokenizer_max_length=components.tokenizer_2_max_length, ) + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len) + + _, seq_len_2, _ = prompt_embeds_2.shape + prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len_2, -1) + prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1).view(batch_size * 
num_videos_per_prompt, seq_len_2) + + block_state.prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + block_state.prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device) + block_state.prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device) + block_state.prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device) + + # Encode negative prompt if guider needs it + if components.requires_unconditional_embeds: + neg_prompt_embeds = getattr(block_state, "negative_prompt_embeds", None) + neg_prompt_embeds_mask = getattr(block_state, "negative_prompt_embeds_mask", None) + neg_prompt_embeds_2 = getattr(block_state, "negative_prompt_embeds_2", None) + neg_prompt_embeds_mask_2 = getattr(block_state, "negative_prompt_embeds_mask_2", None) + + neg_prompt = negative_prompt + if neg_prompt is None: + neg_prompt = [""] * batch_size + neg_prompt = [neg_prompt] if isinstance(neg_prompt, str) else neg_prompt + + if neg_prompt_embeds is None: + neg_prompt_embeds, neg_prompt_embeds_mask = HunyuanVideo15Pipeline._get_mllm_prompt_embeds( + tokenizer=components.tokenizer, + text_encoder=components.text_encoder, + prompt=neg_prompt, + device=device, + tokenizer_max_length=components.tokenizer_max_length, + system_message=components.system_message, + crop_start=components.prompt_template_encode_start_idx, + ) + + if neg_prompt_embeds_2 is None: + neg_prompt_embeds_2, neg_prompt_embeds_mask_2 = HunyuanVideo15Pipeline._get_byt5_prompt_embeds( + tokenizer=components.tokenizer_2, + text_encoder=components.text_encoder_2, + prompt=neg_prompt, + device=device, + tokenizer_max_length=components.tokenizer_2_max_length, + ) + + _, seq_len, _ = neg_prompt_embeds.shape + neg_prompt_embeds = neg_prompt_embeds.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len, -1) + neg_prompt_embeds_mask = neg_prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len) + + _, seq_len_2, _ = 
neg_prompt_embeds_2.shape + neg_prompt_embeds_2 = neg_prompt_embeds_2.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len_2, -1) + neg_prompt_embeds_mask_2 = neg_prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len_2) + + block_state.negative_prompt_embeds = neg_prompt_embeds.to(dtype=dtype, device=device) + block_state.negative_prompt_embeds_mask = neg_prompt_embeds_mask.to(dtype=dtype, device=device) + block_state.negative_prompt_embeds_2 = neg_prompt_embeds_2.to(dtype=dtype, device=device) + block_state.negative_prompt_embeds_mask_2 = neg_prompt_embeds_mask_2.to(dtype=dtype, device=device) + # Pass batch_size downstream state.set("batch_size", batch_size) From e8176d2eb8fbc630d8b866f383774bb85c6317c6 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Thu, 2 Apr 2026 06:52:25 -0700 Subject: [PATCH 09/14] Inline all standard pipeline methods, remove runtime dependency --- .../hunyuan_video1_5/before_denoise.py | 49 ++-- .../hunyuan_video1_5/encoders.py | 263 +++++++++++++----- 2 files changed, 212 insertions(+), 100 deletions(-) diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py index 8c1b65d12c63..07fca3d594c0 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py @@ -18,7 +18,6 @@ import torch from ...models import HunyuanVideo15Transformer3DModel -from ...pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5 import HunyuanVideo15Pipeline from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging from ...utils.torch_utils import randn_tensor @@ -169,14 +168,13 @@ def intermediate_outputs(self) -> list[OutputParam]: OutputParam("image_embeds", type_hint=torch.Tensor), ] - # Copied from pipeline_hunyuan_video1_5.py lines 652-655, 706-725 + # Copied from 
pipeline_hunyuan_video1_5.py lines 652-655, 477-524, 706-725 with self->components @torch.no_grad() def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) device = components._execution_device dtype = block_state.dtype - # Calculate default height/width if not provided (line 652-655) height = block_state.height width = block_state.width if height is None and width is None: @@ -187,28 +185,33 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta batch_size = block_state.batch_size * block_state.num_videos_per_prompt num_frames = block_state.num_frames - # Copied from HunyuanVideo15Pipeline.prepare_latents (lines 477-505, 707-717) - block_state.latents = HunyuanVideo15Pipeline.prepare_latents( - components, - batch_size, - components.num_channels_latents, - height, - width, - num_frames, - dtype, - device, - block_state.generator, - block_state.latents, - ) + # Copied from HunyuanVideo15Pipeline.prepare_latents with self->components + latents = block_state.latents + if latents is not None: + latents = latents.to(device=device, dtype=dtype) + else: + shape = ( + batch_size, + components.num_channels_latents, + (num_frames - 1) // components.vae_scale_factor_temporal + 1, + int(height) // components.vae_scale_factor_spatial, + int(width) // components.vae_scale_factor_spatial, + ) + if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + latents = randn_tensor(shape, generator=block_state.generator, device=device, dtype=dtype) - # Copied from HunyuanVideo15Pipeline.prepare_cond_latents_and_mask (lines 508-524, 718) - cond_latents_concat, mask_concat = HunyuanVideo15Pipeline.prepare_cond_latents_and_mask( - components, block_state.latents, dtype, device - ) - block_state.cond_latents_concat = cond_latents_concat - block_state.mask_concat = mask_concat + block_state.latents = latents + + # Copied from HunyuanVideo15Pipeline.prepare_cond_latents_and_mask with self->components + b, c, f, h, w = latents.shape + block_state.cond_latents_concat = torch.zeros(b, c, f, h, w, dtype=dtype, device=device) + block_state.mask_concat = torch.zeros(b, 1, f, h, w, dtype=dtype, device=device) - # T2V: zero image_embeds (line 719-725) + # T2V: zero image_embeds block_state.image_embeds = torch.zeros( block_state.batch_size, components.vision_num_semantic_tokens, diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py index aa77f0b7dc5d..db4182a5c052 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import re
+
 import torch
 from transformers import ByT5Tokenizer, Qwen2_5_VLTextModel, Qwen2TokenizerFast, T5EncoderModel

 from ...configuration_utils import FrozenDict
 from ...guiders import ClassifierFreeGuidance
-from ...pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5 import HunyuanVideo15Pipeline
 from ...utils import logging
 from ..modular_pipeline import ModularPipelineBlocks, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
@@ -27,6 +28,111 @@
 logger = logging.get_logger(__name__)


+# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.format_text_input
+def format_text_input(prompt, system_message):
+    return [
+        [{"role": "system", "content": system_message}, {"role": "user", "content": p if p else " "}] for p in prompt
+    ]
+
+
+# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.extract_glyph_texts
+def extract_glyph_texts(prompt):
+    # Match text wrapped in straight ASCII quotes OR full-width curly quotes;
+    # the two capture groups feed `match[0] or match[1]` below.
+    pattern = r"\"(.*?)\"|“(.*?)”"
+    matches = re.findall(pattern, prompt)
+    result = [match[0] or match[1] for match in matches]
+    result = list(dict.fromkeys(result)) if len(result) > 1 else result
+    if result:
+        formatted_result = ". ".join([f'Text "{text}"' for text in result]) + ". "
+    else:
+        formatted_result = None
+    return formatted_result
+
+
+# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_mllm_prompt_embeds
+def _get_mllm_prompt_embeds(
+    text_encoder,
+    tokenizer,
+    prompt,
+    device,
+    tokenizer_max_length=1000,
+    num_hidden_layers_to_skip=2,
+    # fmt: off
+    system_message="You are a helpful assistant. Describe the video by detailing the following aspects: \
+    1. The main content and theme of the video. \
+    2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects. \
+    3. Actions, events, behaviors temporal relationships, physical movement changes of the objects. \
+    4. background environment, light, style and atmosphere. \
+    5. 
camera angles, movements, and transitions used in the video.", + # fmt: on + crop_start=108, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = format_text_input(prompt, system_message) + + text_inputs = tokenizer.apply_chat_template( + prompt, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + padding="max_length", + max_length=tokenizer_max_length + crop_start, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + return prompt_embeds, prompt_attention_mask + + +# Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline._get_byt5_prompt_embeds +def _get_byt5_prompt_embeds(tokenizer, text_encoder, prompt, device, tokenizer_max_length=256): + prompt = [prompt] if isinstance(prompt, str) else prompt + glyph_texts = [extract_glyph_texts(p) for p in prompt] + + prompt_embeds_list = [] + prompt_embeds_mask_list = [] + + for glyph_text in glyph_texts: + if glyph_text is None: + glyph_text_embeds = torch.zeros( + (1, tokenizer_max_length, text_encoder.config.d_model), device=device, dtype=text_encoder.dtype + ) + glyph_text_embeds_mask = torch.zeros((1, tokenizer_max_length), device=device, dtype=torch.int64) + else: + txt_tokens = tokenizer( + glyph_text, + padding="max_length", + max_length=tokenizer_max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).to(device) + + glyph_text_embeds = text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask.float(), + )[0] + 
glyph_text_embeds = glyph_text_embeds.to(device=device) + glyph_text_embeds_mask = txt_tokens.attention_mask.to(device=device) + + prompt_embeds_list.append(glyph_text_embeds) + prompt_embeds_mask_list.append(glyph_text_embeds_mask) + + return torch.cat(prompt_embeds_list, dim=0), torch.cat(prompt_embeds_mask_list, dim=0) + + class HunyuanVideo15TextEncoderStep(ModularPipelineBlocks): model_name = "hunyuan-video-1.5" @@ -78,38 +184,29 @@ def intermediate_outputs(self) -> list[OutputParam]: OutputParam("negative_prompt_embeds_mask_2", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"), ] - # Copied from HunyuanVideo15Pipeline.encode_prompt - @torch.no_grad() - def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState: - block_state = self.get_block_state(state) - device = components._execution_device - dtype = components.transformer.dtype - - prompt = block_state.prompt - negative_prompt = block_state.negative_prompt - num_videos_per_prompt = block_state.num_videos_per_prompt - - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - elif getattr(block_state, "prompt_embeds", None) is not None: - batch_size = block_state.prompt_embeds.shape[0] - else: - batch_size = 1 - - # Encode positive prompt - copied from HunyuanVideo15Pipeline.encode_prompt - prompt_embeds = getattr(block_state, "prompt_embeds", None) - prompt_embeds_mask = getattr(block_state, "prompt_embeds_mask", None) - prompt_embeds_2 = getattr(block_state, "prompt_embeds_2", None) - prompt_embeds_mask_2 = getattr(block_state, "prompt_embeds_mask_2", None) + # Copied from diffusers.pipelines.hunyuan_video1_5.pipeline_hunyuan_video1_5.HunyuanVideo15Pipeline.encode_prompt with self->components + @staticmethod + def encode_prompt( + components, + prompt, + device=None, + dtype=None, + batch_size=1, + num_videos_per_prompt=1, + prompt_embeds=None, + 
prompt_embeds_mask=None, + prompt_embeds_2=None, + prompt_embeds_mask_2=None, + ): + device = device or components._execution_device + dtype = dtype or components.text_encoder.dtype if prompt is None: prompt = [""] * batch_size prompt = [prompt] if isinstance(prompt, str) else prompt if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = HunyuanVideo15Pipeline._get_mllm_prompt_embeds( + prompt_embeds, prompt_embeds_mask = _get_mllm_prompt_embeds( tokenizer=components.tokenizer, text_encoder=components.text_encoder, prompt=prompt, @@ -120,7 +217,7 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta ) if prompt_embeds_2 is None: - prompt_embeds_2, prompt_embeds_mask_2 = HunyuanVideo15Pipeline._get_byt5_prompt_embeds( + prompt_embeds_2, prompt_embeds_mask_2 = _get_byt5_prompt_embeds( tokenizer=components.tokenizer_2, text_encoder=components.text_encoder_2, prompt=prompt, @@ -136,57 +233,69 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len_2, -1) prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len_2) - block_state.prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - block_state.prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device) - block_state.prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device) - block_state.prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device) + prompt_embeds_2 = prompt_embeds_2.to(dtype=dtype, device=device) + prompt_embeds_mask_2 = prompt_embeds_mask_2.to(dtype=dtype, device=device) + + return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 + + @torch.no_grad() + 
def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + dtype = components.transformer.dtype + + prompt = block_state.prompt + negative_prompt = block_state.negative_prompt + num_videos_per_prompt = block_state.num_videos_per_prompt + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + elif getattr(block_state, "prompt_embeds", None) is not None: + batch_size = block_state.prompt_embeds.shape[0] + else: + batch_size = 1 + + ( + block_state.prompt_embeds, + block_state.prompt_embeds_mask, + block_state.prompt_embeds_2, + block_state.prompt_embeds_mask_2, + ) = self.encode_prompt( + components, + prompt=prompt, + device=device, + dtype=dtype, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=getattr(block_state, "prompt_embeds", None), + prompt_embeds_mask=getattr(block_state, "prompt_embeds_mask", None), + prompt_embeds_2=getattr(block_state, "prompt_embeds_2", None), + prompt_embeds_mask_2=getattr(block_state, "prompt_embeds_mask_2", None), + ) - # Encode negative prompt if guider needs it if components.requires_unconditional_embeds: - neg_prompt_embeds = getattr(block_state, "negative_prompt_embeds", None) - neg_prompt_embeds_mask = getattr(block_state, "negative_prompt_embeds_mask", None) - neg_prompt_embeds_2 = getattr(block_state, "negative_prompt_embeds_2", None) - neg_prompt_embeds_mask_2 = getattr(block_state, "negative_prompt_embeds_mask_2", None) - - neg_prompt = negative_prompt - if neg_prompt is None: - neg_prompt = [""] * batch_size - neg_prompt = [neg_prompt] if isinstance(neg_prompt, str) else neg_prompt - - if neg_prompt_embeds is None: - neg_prompt_embeds, neg_prompt_embeds_mask = HunyuanVideo15Pipeline._get_mllm_prompt_embeds( - tokenizer=components.tokenizer, - 
text_encoder=components.text_encoder, - prompt=neg_prompt, - device=device, - tokenizer_max_length=components.tokenizer_max_length, - system_message=components.system_message, - crop_start=components.prompt_template_encode_start_idx, - ) - - if neg_prompt_embeds_2 is None: - neg_prompt_embeds_2, neg_prompt_embeds_mask_2 = HunyuanVideo15Pipeline._get_byt5_prompt_embeds( - tokenizer=components.tokenizer_2, - text_encoder=components.text_encoder_2, - prompt=neg_prompt, - device=device, - tokenizer_max_length=components.tokenizer_2_max_length, - ) - - _, seq_len, _ = neg_prompt_embeds.shape - neg_prompt_embeds = neg_prompt_embeds.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len, -1) - neg_prompt_embeds_mask = neg_prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len) - - _, seq_len_2, _ = neg_prompt_embeds_2.shape - neg_prompt_embeds_2 = neg_prompt_embeds_2.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len_2, -1) - neg_prompt_embeds_mask_2 = neg_prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len_2) - - block_state.negative_prompt_embeds = neg_prompt_embeds.to(dtype=dtype, device=device) - block_state.negative_prompt_embeds_mask = neg_prompt_embeds_mask.to(dtype=dtype, device=device) - block_state.negative_prompt_embeds_2 = neg_prompt_embeds_2.to(dtype=dtype, device=device) - block_state.negative_prompt_embeds_mask_2 = neg_prompt_embeds_mask_2.to(dtype=dtype, device=device) - - # Pass batch_size downstream + ( + block_state.negative_prompt_embeds, + block_state.negative_prompt_embeds_mask, + block_state.negative_prompt_embeds_2, + block_state.negative_prompt_embeds_mask_2, + ) = self.encode_prompt( + components, + prompt=negative_prompt, + device=device, + dtype=dtype, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=getattr(block_state, 
"negative_prompt_embeds", None), + prompt_embeds_mask=getattr(block_state, "negative_prompt_embeds_mask", None), + prompt_embeds_2=getattr(block_state, "negative_prompt_embeds_2", None), + prompt_embeds_mask_2=getattr(block_state, "negative_prompt_embeds_mask_2", None), + ) + state.set("batch_size", batch_size) self.set_block_state(state, block_state) From e8f99f9dc7c49ca1053be2c57554895b292a44b1 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Thu, 2 Apr 2026 07:43:24 -0700 Subject: [PATCH 10/14] Add HunyuanVideo 1.5 image-to-video modular blocks --- src/diffusers/__init__.py | 1 + src/diffusers/modular_pipelines/__init__.py | 3 +- .../hunyuan_video1_5/__init__.py | 4 +- .../hunyuan_video1_5/before_denoise.py | 122 ++++++++++++++++++ .../hunyuan_video1_5/denoise.py | 119 +++++++++++++++++ .../modular_blocks_hunyuan_video1_5.py | 42 +++++- 6 files changed, 287 insertions(+), 4 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index c4cab87fa533..6d79e9733381 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -453,6 +453,7 @@ "HeliosPyramidDistilledModularPipeline", "HeliosPyramidModularPipeline", "HunyuanVideo15Blocks", + "HunyuanVideo15Image2VideoBlocks", "HunyuanVideo15ModularPipeline", "QwenImageAutoBlocks", "QwenImageEditAutoBlocks", diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 01d8c626b3f1..ae8cb9762f21 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -90,6 +90,7 @@ ] _import_structure["hunyuan_video1_5"] = [ "HunyuanVideo15Blocks", + "HunyuanVideo15Image2VideoBlocks", "HunyuanVideo15ModularPipeline", ] _import_structure["z_image"] = [ @@ -144,7 +145,7 @@ QwenImageLayeredModularPipeline, QwenImageModularPipeline, ) - from .hunyuan_video1_5 import HunyuanVideo15Blocks, HunyuanVideo15ModularPipeline + from .hunyuan_video1_5 import HunyuanVideo15Blocks, 
HunyuanVideo15Image2VideoBlocks, HunyuanVideo15ModularPipeline from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline from .wan import ( Wan22Blocks, diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/__init__.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/__init__.py index 3ad7d17a8357..73de8277d004 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/__init__.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/__init__.py @@ -21,7 +21,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["modular_blocks_hunyuan_video1_5"] = ["HunyuanVideo15Blocks"] + _import_structure["modular_blocks_hunyuan_video1_5"] = ["HunyuanVideo15Blocks", "HunyuanVideo15Image2VideoBlocks"] _import_structure["modular_pipeline"] = ["HunyuanVideo15ModularPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -31,7 +31,7 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .modular_blocks_hunyuan_video1_5 import HunyuanVideo15Blocks + from .modular_blocks_hunyuan_video1_5 import HunyuanVideo15Blocks, HunyuanVideo15Image2VideoBlocks from .modular_pipeline import HunyuanVideo15ModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py index 07fca3d594c0..cd6e4dee15d7 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py @@ -222,3 +222,125 @@ def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineSta self.set_block_state(state, block_state) return components, state + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents(encoder_output, generator=None, sample_mode="sample"): + if 
hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + raise AttributeError("Could not access latents of provided encoder_output") + + +class HunyuanVideo15Image2VideoPrepareLatentsStep(ModularPipelineBlocks): + model_name = "hunyuan-video-1.5" + + @property + def description(self) -> str: + return "Prepare latents, conditioning latents, mask, and image_embeds for I2V" + + @property + def expected_components(self) -> list[ComponentSpec]: + from ...models import AutoencoderKLHunyuanVideo15 + from ...pipelines.hunyuan_video1_5.image_processor import HunyuanVideo15ImageProcessor + from transformers import SiglipVisionModel, SiglipImageProcessor + return [ + ComponentSpec("vae", AutoencoderKLHunyuanVideo15), + ComponentSpec( + "video_processor", + HunyuanVideo15ImageProcessor, + config=FrozenDict({"vae_scale_factor": 16}), + default_creation_method="from_config", + ), + ComponentSpec("image_encoder", SiglipVisionModel), + ComponentSpec("feature_extractor", SiglipImageProcessor), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("image", required=True), + InputParam("num_frames", type_hint=int, default=121), + InputParam("latents", type_hint=torch.Tensor | None), + InputParam("num_videos_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam("batch_size", required=True, type_hint=int), + InputParam("dtype", type_hint=torch.dtype), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("latents", type_hint=torch.Tensor), + OutputParam("cond_latents_concat", type_hint=torch.Tensor), + OutputParam("mask_concat", type_hint=torch.Tensor), + OutputParam("image_embeds", type_hint=torch.Tensor), + ] + + # Copied from 
pipeline_hunyuan_video1_5_image2video.py lines 756-839 with self->components + @torch.no_grad() + def __call__(self, components: HunyuanVideo15ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + dtype = block_state.dtype + + image = block_state.image + batch_size = block_state.batch_size * block_state.num_videos_per_prompt + num_frames = block_state.num_frames + + # Resize/crop image to target resolution (line 756-759) + height, width = components.video_processor.calculate_default_height_width( + height=image.size[1], width=image.size[0], target_size=components.target_size + ) + image = components.video_processor.resize(image, height=height, width=width, resize_mode="crop") + + # Encode image with Siglip (lines 776-781) + image_encoder_dtype = next(components.image_encoder.parameters()).dtype + image_inputs = components.feature_extractor.preprocess( + images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True + ) + image_inputs = image_inputs.to(device=device, dtype=image_encoder_dtype) + image_embeds = components.image_encoder(**image_inputs).last_hidden_state + image_embeds = image_embeds.repeat(batch_size, 1, 1) + block_state.image_embeds = image_embeds.to(device=device, dtype=dtype) + + # Prepare latents (lines 818-829) + latents = block_state.latents + if latents is not None: + latents = latents.to(device=device, dtype=dtype) + else: + shape = ( + batch_size, + components.num_channels_latents, + (num_frames - 1) // components.vae_scale_factor_temporal + 1, + int(height) // components.vae_scale_factor_spatial, + int(width) // components.vae_scale_factor_spatial, + ) + latents = randn_tensor(shape, generator=block_state.generator, device=device, dtype=dtype) + block_state.latents = latents + + # Prepare cond latents and mask (lines 594-632, 831-839) + b, c, f, h, w = latents.shape + + # Copied from _get_image_latents (lines 375-388) with self->components + 
vae_dtype = components.vae.dtype + image_tensor = components.video_processor.preprocess( + image, height=h * components.vae_scale_factor_spatial, width=w * components.vae_scale_factor_spatial + ).to(device, dtype=vae_dtype) + image_tensor = image_tensor.unsqueeze(2) + image_latents = retrieve_latents(components.vae.encode(image_tensor), sample_mode="argmax") + image_latents = image_latents * components.vae.config.scaling_factor + + latent_condition = image_latents.repeat(batch_size, 1, f, 1, 1) + latent_condition[:, :, 1:, :, :] = 0 + block_state.cond_latents_concat = latent_condition.to(device=device, dtype=dtype) + + latent_mask = torch.zeros(b, 1, f, h, w, dtype=dtype, device=device) + latent_mask[:, :, 0, :, :] = 1.0 + block_state.mask_concat = latent_mask + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py index e66e4701b4cb..878c0072011f 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py @@ -237,3 +237,122 @@ def description(self) -> str: " - `HunyuanVideo15LoopAfterDenoiser`\n" "This block supports text-to-video tasks." 
) + + +class HunyuanVideo15Image2VideoLoopDenoiser(ModularPipelineBlocks): + model_name = "hunyuan-video-1.5" + + def __init__(self, guider_input_fields=None): + if guider_input_fields is None: + guider_input_fields = { + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + "encoder_attention_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"), + "encoder_hidden_states_2": ("prompt_embeds_2", "negative_prompt_embeds_2"), + "encoder_attention_mask_2": ("prompt_embeds_mask_2", "negative_prompt_embeds_mask_2"), + } + if not isinstance(guider_input_fields, dict): + raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.5}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", HunyuanVideo15Transformer3DModel), + ] + + @property + def description(self) -> str: + return "I2V denoiser with MeanFlow timestep_r support" + + @property + def inputs(self) -> list[InputParam]: + inputs = [ + InputParam("attention_kwargs"), + InputParam("num_inference_steps", required=True, type_hint=int), + InputParam("image_embeds", type_hint=torch.Tensor), + InputParam("timesteps", required=True, type_hint=torch.Tensor), + ] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + inputs.append(InputParam(name=value[0], required=True, type_hint=torch.Tensor)) + for neg_name in value[1:]: + inputs.append(InputParam(name=neg_name, type_hint=torch.Tensor)) + else: + inputs.append(InputParam(name=value, required=True, type_hint=torch.Tensor)) + return inputs + + # Copied from pipeline_hunyuan_video1_5_image2video.py lines 853-912 with self->components + @torch.no_grad() + def __call__( + self, components: 
HunyuanVideo15ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + timestep = t.expand(block_state.latent_model_input.shape[0]).to(block_state.latent_model_input.dtype) + + # MeanFlow timestep_r (lines 855-862) + if components.transformer.config.use_meanflow: + if i == len(block_state.timesteps) - 1: + timestep_r = torch.tensor([0.0], device=timestep.device) + else: + timestep_r = block_state.timesteps[i + 1] + timestep_r = timestep_r.expand(block_state.latents.shape[0]).to(block_state.latents.dtype) + else: + timestep_r = None + + guider_inputs = { + input_name: tuple(getattr(block_state, v) for v in value) if isinstance(value, tuple) else getattr(block_state, value) + for input_name, value in self._guider_input_fields.items() + } + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(guider_inputs) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + + cond_kwargs = { + input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() + } + + context_name = getattr(guider_state_batch, components.guider._identifier_key) + with components.transformer.cache_context(context_name): + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latent_model_input, + image_embeds=block_state.image_embeds, + timestep=timestep, + timestep_r=timestep_r, + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + + components.guider.cleanup_models(components.transformer) + + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +class HunyuanVideo15Image2VideoDenoiseStep(HunyuanVideo15DenoiseLoopWrapper): + block_classes = [ + HunyuanVideo15LoopBeforeDenoiser, + HunyuanVideo15Image2VideoLoopDenoiser(), + HunyuanVideo15LoopAfterDenoiser, + ] + block_names = 
["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step for image-to-video with MeanFlow support.\n" + "At each iteration:\n" + " - `HunyuanVideo15LoopBeforeDenoiser`\n" + " - `HunyuanVideo15Image2VideoLoopDenoiser`\n" + " - `HunyuanVideo15LoopAfterDenoiser`" + ) diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py index 1ae6970deeb1..e93765cdd929 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py @@ -16,12 +16,13 @@ from ..modular_pipeline import SequentialPipelineBlocks from ..modular_pipeline_utils import OutputParam from .before_denoise import ( + HunyuanVideo15Image2VideoPrepareLatentsStep, HunyuanVideo15PrepareLatentsStep, HunyuanVideo15SetTimestepsStep, HunyuanVideo15TextInputStep, ) from .decoders import HunyuanVideo15VaeDecoderStep -from .denoise import HunyuanVideo15DenoiseStep +from .denoise import HunyuanVideo15DenoiseStep, HunyuanVideo15Image2VideoDenoiseStep from .encoders import HunyuanVideo15TextEncoderStep @@ -65,3 +66,42 @@ def description(self): @property def outputs(self): return [OutputParam.template("videos")] + + +# auto_docstring +class HunyuanVideo15Image2VideoCoreDenoiseStep(SequentialPipelineBlocks): + model_name = "hunyuan-video-1.5" + block_classes = [ + HunyuanVideo15TextInputStep, + HunyuanVideo15SetTimestepsStep, + HunyuanVideo15Image2VideoPrepareLatentsStep, + HunyuanVideo15Image2VideoDenoiseStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] + + @property + def description(self): + return "Denoise block for image-to-video that takes encoded conditions and runs the denoising process." 
+ + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# auto_docstring +class HunyuanVideo15Image2VideoBlocks(SequentialPipelineBlocks): + model_name = "hunyuan-video-1.5" + block_classes = [ + HunyuanVideo15TextEncoderStep, + HunyuanVideo15Image2VideoCoreDenoiseStep, + HunyuanVideo15VaeDecoderStep, + ] + block_names = ["text_encoder", "denoise", "decode"] + + @property + def description(self): + return "Modular pipeline blocks for HunyuanVideo 1.5 image-to-video." + + @property + def outputs(self): + return [OutputParam.template("videos")] From 562fa49f048d0ff282d8089757162be255dc562c Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Thu, 2 Apr 2026 08:25:49 -0700 Subject: [PATCH 11/14] Fix missing FrozenDict import in before_denoise.py --- .../modular_pipelines/hunyuan_video1_5/before_denoise.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py index cd6e4dee15d7..ef7fede8557d 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py @@ -17,6 +17,7 @@ import numpy as np import torch +from ...configuration_utils import FrozenDict from ...models import HunyuanVideo15Transformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging From bd45ef6d570a0ec34700ad75afcd2a9cd53dc2c4 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Thu, 2 Apr 2026 09:26:06 -0700 Subject: [PATCH 12/14] auto-generated docstrings via #auto_docstring --- .../modular_blocks_hunyuan_video1_5.py | 230 ++++++++++++++++++ 1 file changed, 230 insertions(+) diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py index e93765cdd929..4cdeb68fa65e 100644 --- 
a/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py @@ -31,6 +31,57 @@ # auto_docstring class HunyuanVideo15CoreDenoiseStep(SequentialPipelineBlocks): + """ + Denoise block that takes encoded conditions and runs the denoising process. + + Components: + transformer (`HunyuanVideo15Transformer3DModel`) + scheduler (`FlowMatchEulerDiscreteScheduler`) + guider (`ClassifierFreeGuidance`) + + Inputs: + num_videos_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + TODO: Add description. + batch_size (`int`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + num_frames (`int`, *optional*, defaults to 121): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds (`Tensor`, *optional*): + TODO: Add description. + prompt_embeds_mask (`Tensor`): + TODO: Add description. + negative_prompt_embeds_mask (`Tensor`, *optional*): + TODO: Add description. + prompt_embeds_2 (`Tensor`): + TODO: Add description. + negative_prompt_embeds_2 (`Tensor`, *optional*): + TODO: Add description. + prompt_embeds_mask_2 (`Tensor`): + TODO: Add description. + negative_prompt_embeds_mask_2 (`Tensor`, *optional*): + TODO: Add description. + + Outputs: + latents (`Tensor`): + Denoised latents. 
+ """ + model_name = "hunyuan-video-1.5" block_classes = [ HunyuanVideo15TextInputStep, @@ -51,6 +102,69 @@ def outputs(self): # auto_docstring class HunyuanVideo15Blocks(SequentialPipelineBlocks): + """ + Modular pipeline blocks for HunyuanVideo 1.5 text-to-video. + + Components: + text_encoder (`Qwen2_5_VLTextModel`) + tokenizer (`Qwen2TokenizerFast`) + text_encoder_2 (`T5EncoderModel`) + tokenizer_2 (`ByT5Tokenizer`) + guider (`ClassifierFreeGuidance`) + transformer (`HunyuanVideo15Transformer3DModel`) + scheduler (`FlowMatchEulerDiscreteScheduler`) + vae (`AutoencoderKLHunyuanVideo15`) + video_processor (`HunyuanVideo15ImageProcessor`) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + negative_prompt (`None`, *optional*): + TODO: Add description. + prompt_embeds (`Tensor`, *optional*): + TODO: Add description. + prompt_embeds_mask (`Tensor`, *optional*): + TODO: Add description. + negative_prompt_embeds (`Tensor`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`Tensor`, *optional*): + TODO: Add description. + prompt_embeds_2 (`Tensor`, *optional*): + TODO: Add description. + prompt_embeds_mask_2 (`Tensor`, *optional*): + TODO: Add description. + negative_prompt_embeds_2 (`Tensor`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask_2 (`Tensor`, *optional*): + TODO: Add description. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + TODO: Add description. + batch_size (`int`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + height (`int`, *optional*): + TODO: Add description. + width (`int`, *optional*): + TODO: Add description. + num_frames (`int`, *optional*, defaults to 121): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. 
+ attention_kwargs (`None`, *optional*): + TODO: Add description. + output_type (`str`, *optional*, defaults to np): + TODO: Add description. + + Outputs: + videos (`list`): + The generated videos. + """ + model_name = "hunyuan-video-1.5" block_classes = [ HunyuanVideo15TextEncoderStep, @@ -70,6 +184,59 @@ def outputs(self): # auto_docstring class HunyuanVideo15Image2VideoCoreDenoiseStep(SequentialPipelineBlocks): + """ + Denoise block for image-to-video that takes encoded conditions and runs the denoising process. + + Components: + transformer (`HunyuanVideo15Transformer3DModel`) + scheduler (`FlowMatchEulerDiscreteScheduler`) + vae (`AutoencoderKLHunyuanVideo15`) + video_processor (`HunyuanVideo15ImageProcessor`) + image_encoder (`SiglipVisionModel`) + feature_extractor (`SiglipImageProcessor`) + guider (`ClassifierFreeGuidance`) + + Inputs: + num_videos_per_prompt (`None`, *optional*, defaults to 1): + TODO: Add description. + prompt_embeds (`Tensor`): + TODO: Add description. + batch_size (`int`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + image (`None`): + TODO: Add description. + num_frames (`int`, *optional*, defaults to 121): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. + generator (`None`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + negative_prompt_embeds (`Tensor`, *optional*): + TODO: Add description. + prompt_embeds_mask (`Tensor`): + TODO: Add description. + negative_prompt_embeds_mask (`Tensor`, *optional*): + TODO: Add description. + prompt_embeds_2 (`Tensor`): + TODO: Add description. + negative_prompt_embeds_2 (`Tensor`, *optional*): + TODO: Add description. + prompt_embeds_mask_2 (`Tensor`): + TODO: Add description. + negative_prompt_embeds_mask_2 (`Tensor`, *optional*): + TODO: Add description. 
+ + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "hunyuan-video-1.5" block_classes = [ HunyuanVideo15TextInputStep, @@ -90,6 +257,69 @@ def outputs(self): # auto_docstring class HunyuanVideo15Image2VideoBlocks(SequentialPipelineBlocks): + """ + Modular pipeline blocks for HunyuanVideo 1.5 image-to-video. + + Components: + text_encoder (`Qwen2_5_VLTextModel`) + tokenizer (`Qwen2TokenizerFast`) + text_encoder_2 (`T5EncoderModel`) + tokenizer_2 (`ByT5Tokenizer`) + guider (`ClassifierFreeGuidance`) + transformer (`HunyuanVideo15Transformer3DModel`) + scheduler (`FlowMatchEulerDiscreteScheduler`) + vae (`AutoencoderKLHunyuanVideo15`) + video_processor (`HunyuanVideo15ImageProcessor`) + image_encoder (`SiglipVisionModel`) + feature_extractor (`SiglipImageProcessor`) + + Inputs: + prompt (`None`, *optional*): + TODO: Add description. + negative_prompt (`None`, *optional*): + TODO: Add description. + prompt_embeds (`Tensor`, *optional*): + TODO: Add description. + prompt_embeds_mask (`Tensor`, *optional*): + TODO: Add description. + negative_prompt_embeds (`Tensor`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask (`Tensor`, *optional*): + TODO: Add description. + prompt_embeds_2 (`Tensor`, *optional*): + TODO: Add description. + prompt_embeds_mask_2 (`Tensor`, *optional*): + TODO: Add description. + negative_prompt_embeds_2 (`Tensor`, *optional*): + TODO: Add description. + negative_prompt_embeds_mask_2 (`Tensor`, *optional*): + TODO: Add description. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + TODO: Add description. + batch_size (`int`, *optional*): + TODO: Add description. + num_inference_steps (`None`, *optional*, defaults to 50): + TODO: Add description. + sigmas (`None`, *optional*): + TODO: Add description. + image (`None`): + TODO: Add description. + num_frames (`int`, *optional*, defaults to 121): + TODO: Add description. + latents (`Tensor | NoneType`, *optional*): + TODO: Add description. 
+ generator (`None`, *optional*): + TODO: Add description. + attention_kwargs (`None`, *optional*): + TODO: Add description. + output_type (`str`, *optional*, defaults to np): + TODO: Add description. + + Outputs: + videos (`list`): + The generated videos. + """ + model_name = "hunyuan-video-1.5" block_classes = [ HunyuanVideo15TextEncoderStep, From e439012494cb995b326201824282149cff5a0477 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Thu, 2 Apr 2026 17:47:24 -0700 Subject: [PATCH 13/14] Fix ruff lint and format issues --- .../hunyuan_video1_5/before_denoise.py | 4 +++- .../hunyuan_video1_5/decoders.py | 4 +++- .../hunyuan_video1_5/denoise.py | 17 ++++++++--------- .../hunyuan_video1_5/encoders.py | 16 ++++++++++++---- 4 files changed, 26 insertions(+), 15 deletions(-) diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py index ef7fede8557d..4facf60854b4 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py @@ -245,9 +245,11 @@ def description(self) -> str: @property def expected_components(self) -> list[ComponentSpec]: + from transformers import SiglipImageProcessor, SiglipVisionModel + from ...models import AutoencoderKLHunyuanVideo15 from ...pipelines.hunyuan_video1_5.image_processor import HunyuanVideo15ImageProcessor - from transformers import SiglipVisionModel, SiglipImageProcessor + return [ ComponentSpec("vae", AutoencoderKLHunyuanVideo15), ComponentSpec( diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/decoders.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/decoders.py index f5eddb16b5ed..260392d0d3cf 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/decoders.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/decoders.py @@ -75,7 +75,9 @@ def __call__(self, components, state: PipelineState) -> PipelineState: else: 
latents = block_state.latents.to(components.vae.dtype) / components.vae.config.scaling_factor video = components.vae.decode(latents, return_dict=False)[0] - block_state.videos = components.video_processor.postprocess_video(video, output_type=block_state.output_type) + block_state.videos = components.video_processor.postprocess_video( + video, output_type=block_state.output_type + ) self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py index 878c0072011f..ae0fba1bb3e5 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any import torch @@ -115,7 +114,9 @@ def __call__( # Step 1: Collect model inputs guider_inputs = { - input_name: tuple(getattr(block_state, v) for v in value) if isinstance(value, tuple) else getattr(block_state, value) + input_name: tuple(getattr(block_state, v) for v in value) + if isinstance(value, tuple) + else getattr(block_state, value) for input_name, value in self._guider_input_fields.items() } @@ -129,9 +130,7 @@ def __call__( for guider_state_batch in guider_state: components.guider.prepare_models(components.transformer) - cond_kwargs = { - input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() - } + cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} context_name = getattr(guider_state_batch, components.guider._identifier_key) with components.transformer.cache_context(context_name): @@ -306,7 +305,9 @@ def __call__( timestep_r = None guider_inputs = { - input_name: tuple(getattr(block_state, v) for v in value) if isinstance(value, tuple) else getattr(block_state, value) + 
input_name: tuple(getattr(block_state, v) for v in value) + if isinstance(value, tuple) + else getattr(block_state, value) for input_name, value in self._guider_input_fields.items() } @@ -316,9 +317,7 @@ def __call__( for guider_state_batch in guider_state: components.guider.prepare_models(components.transformer) - cond_kwargs = { - input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys() - } + cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} context_name = getattr(guider_state_batch, components.guider._identifier_key) with components.transformer.cache_context(context_name): diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py index db4182a5c052..52f72db00f05 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py @@ -226,12 +226,20 @@ def encode_prompt( ) _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len) + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1).view( + batch_size * num_videos_per_prompt, seq_len, -1 + ) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_videos_per_prompt, 1).view( + batch_size * num_videos_per_prompt, seq_len + ) _, seq_len_2, _ = prompt_embeds_2.shape - prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len_2, -1) - prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1).view(batch_size * num_videos_per_prompt, seq_len_2) + prompt_embeds_2 = prompt_embeds_2.repeat(1, num_videos_per_prompt, 1).view( + batch_size * num_videos_per_prompt, 
seq_len_2, -1 + ) + prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_videos_per_prompt, 1).view( + batch_size * num_videos_per_prompt, seq_len_2 + ) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) prompt_embeds_mask = prompt_embeds_mask.to(dtype=dtype, device=device) From 330c5f677494cb0aeffb98413b00b2e7592fd67b Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Thu, 2 Apr 2026 20:01:49 -0700 Subject: [PATCH 14/14] use InputParam/OutputParam templates and fix ruff --- src/diffusers/modular_pipelines/__init__.py | 6 +- .../hunyuan_video1_5/before_denoise.py | 40 ++--- .../hunyuan_video1_5/decoders.py | 15 +- .../hunyuan_video1_5/denoise.py | 16 +- .../hunyuan_video1_5/encoders.py | 22 +-- .../modular_blocks_hunyuan_video1_5.py | 160 +++++++++--------- 6 files changed, 130 insertions(+), 129 deletions(-) diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index ae8cb9762f21..a76828291cd6 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -124,6 +124,11 @@ HeliosPyramidDistilledModularPipeline, HeliosPyramidModularPipeline, ) + from .hunyuan_video1_5 import ( + HunyuanVideo15Blocks, + HunyuanVideo15Image2VideoBlocks, + HunyuanVideo15ModularPipeline, + ) from .modular_pipeline import ( AutoPipelineBlocks, BlockState, @@ -145,7 +150,6 @@ QwenImageLayeredModularPipeline, QwenImageModularPipeline, ) - from .hunyuan_video1_5 import HunyuanVideo15Blocks, HunyuanVideo15Image2VideoBlocks, HunyuanVideo15ModularPipeline from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline from .wan import ( Wan22Blocks, diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py index 4facf60854b4..0c3a3647f878 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py +++ 
b/src/diffusers/modular_pipelines/hunyuan_video1_5/before_denoise.py @@ -76,9 +76,9 @@ def expected_components(self) -> list[ComponentSpec]: @property def inputs(self) -> list[InputParam]: return [ - InputParam("num_videos_per_prompt", default=1), - InputParam("prompt_embeds", required=True, type_hint=torch.Tensor), - InputParam("batch_size", type_hint=int), + InputParam.template("num_images_per_prompt", name="num_videos_per_prompt"), + InputParam.template("prompt_embeds"), + InputParam.template("batch_size", default=None), ] @property @@ -111,8 +111,8 @@ def description(self) -> str: @property def inputs(self) -> list[InputParam]: return [ - InputParam("num_inference_steps", default=50), - InputParam("sigmas"), + InputParam.template("num_inference_steps"), + InputParam.template("sigmas"), ] @property @@ -150,20 +150,20 @@ def description(self) -> str: @property def inputs(self) -> list[InputParam]: return [ - InputParam("height", type_hint=int), - InputParam("width", type_hint=int), + InputParam.template("height"), + InputParam.template("width"), InputParam("num_frames", type_hint=int, default=121), - InputParam("latents", type_hint=torch.Tensor | None), - InputParam("num_videos_per_prompt", type_hint=int, default=1), - InputParam("generator"), - InputParam("batch_size", required=True, type_hint=int), - InputParam("dtype", type_hint=torch.dtype), + InputParam.template("latents"), + InputParam.template("num_images_per_prompt", name="num_videos_per_prompt"), + InputParam.template("generator"), + InputParam.template("batch_size", required=True, default=None), + InputParam.template("dtype", default=None), ] @property def intermediate_outputs(self) -> list[OutputParam]: return [ - OutputParam("latents", type_hint=torch.Tensor), + OutputParam.template("latents"), OutputParam("cond_latents_concat", type_hint=torch.Tensor), OutputParam("mask_concat", type_hint=torch.Tensor), OutputParam("image_embeds", type_hint=torch.Tensor), @@ -265,19 +265,19 @@ def 
expected_components(self) -> list[ComponentSpec]: @property def inputs(self) -> list[InputParam]: return [ - InputParam("image", required=True), + InputParam.template("image"), InputParam("num_frames", type_hint=int, default=121), - InputParam("latents", type_hint=torch.Tensor | None), - InputParam("num_videos_per_prompt", type_hint=int, default=1), - InputParam("generator"), - InputParam("batch_size", required=True, type_hint=int), - InputParam("dtype", type_hint=torch.dtype), + InputParam.template("latents"), + InputParam.template("num_images_per_prompt", name="num_videos_per_prompt"), + InputParam.template("generator"), + InputParam.template("batch_size", required=True, default=None), + InputParam.template("dtype", default=None), ] @property def intermediate_outputs(self) -> list[OutputParam]: return [ - OutputParam("latents", type_hint=torch.Tensor), + OutputParam.template("latents"), OutputParam("cond_latents_concat", type_hint=torch.Tensor), OutputParam("mask_concat", type_hint=torch.Tensor), OutputParam("image_embeds", type_hint=torch.Tensor), diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/decoders.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/decoders.py index 260392d0d3cf..b9b673aa06a8 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/decoders.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/decoders.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any -import numpy as np -import PIL import torch from ...configuration_utils import FrozenDict @@ -49,20 +46,16 @@ def description(self) -> str: return "Step that decodes the denoised latents into videos" @property - def inputs(self) -> list[tuple[str, Any]]: + def inputs(self) -> list[InputParam]: return [ - InputParam("latents", required=True, type_hint=torch.Tensor), - InputParam("output_type", default="np", type_hint=str), + InputParam.template("latents", required=True), + InputParam.template("output_type", default="np"), ] @property def intermediate_outputs(self) -> list[OutputParam]: return [ - OutputParam( - "videos", - type_hint=list[list[PIL.Image.Image]] | list[torch.Tensor] | list[np.ndarray], - description="The generated videos", - ) + OutputParam.template("videos"), ] # Copied from pipeline_hunyuan_video1_5.py lines 823-829 diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py index ae0fba1bb3e5..d1dd024b5f90 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/denoise.py @@ -43,7 +43,7 @@ def description(self) -> str: @property def inputs(self) -> list[InputParam]: return [ - InputParam("latents", required=True, type_hint=torch.Tensor), + InputParam.template("latents", required=True), InputParam("cond_latents_concat", required=True, type_hint=torch.Tensor), InputParam("mask_concat", required=True, type_hint=torch.Tensor), ] @@ -92,8 +92,8 @@ def description(self) -> str: @property def inputs(self) -> list[InputParam]: inputs = [ - InputParam("attention_kwargs"), - InputParam("num_inference_steps", required=True, type_hint=int), + InputParam.template("attention_kwargs"), + InputParam.template("num_inference_steps", required=True, default=None), InputParam("image_embeds", type_hint=torch.Tensor), ] for value in self._guider_input_fields.values(): @@ -194,8 +194,8 @@ def 
loop_expected_components(self) -> list[ComponentSpec]: @property def loop_inputs(self) -> list[InputParam]: return [ - InputParam("timesteps", required=True, type_hint=torch.Tensor), - InputParam("num_inference_steps", required=True, type_hint=int), + InputParam.template("timesteps", required=True), + InputParam.template("num_inference_steps", required=True, default=None), ] @torch.no_grad() @@ -273,10 +273,10 @@ def description(self) -> str: @property def inputs(self) -> list[InputParam]: inputs = [ - InputParam("attention_kwargs"), - InputParam("num_inference_steps", required=True, type_hint=int), + InputParam.template("attention_kwargs"), + InputParam.template("num_inference_steps", required=True, default=None), InputParam("image_embeds", type_hint=torch.Tensor), - InputParam("timesteps", required=True, type_hint=torch.Tensor), + InputParam.template("timesteps", required=True), ] for value in self._guider_input_fields.values(): if isinstance(value, tuple): diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py index 52f72db00f05..cf47a772c7aa 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/encoders.py @@ -158,26 +158,26 @@ def expected_components(self) -> list[ComponentSpec]: @property def inputs(self) -> list[InputParam]: return [ - InputParam("prompt"), - InputParam("negative_prompt"), - InputParam("prompt_embeds", type_hint=torch.Tensor), - InputParam("prompt_embeds_mask", type_hint=torch.Tensor), - InputParam("negative_prompt_embeds", type_hint=torch.Tensor), - InputParam("negative_prompt_embeds_mask", type_hint=torch.Tensor), + InputParam.template("prompt", required=False), + InputParam.template("negative_prompt"), + InputParam.template("prompt_embeds", required=False), + InputParam.template("prompt_embeds_mask", required=False), + InputParam.template("negative_prompt_embeds"), + 
InputParam.template("negative_prompt_embeds_mask"), InputParam("prompt_embeds_2", type_hint=torch.Tensor), InputParam("prompt_embeds_mask_2", type_hint=torch.Tensor), InputParam("negative_prompt_embeds_2", type_hint=torch.Tensor), InputParam("negative_prompt_embeds_mask_2", type_hint=torch.Tensor), - InputParam("num_videos_per_prompt", type_hint=int, default=1), + InputParam.template("num_images_per_prompt", name="num_videos_per_prompt"), ] @property def intermediate_outputs(self) -> list[OutputParam]: return [ - OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"), - OutputParam("prompt_embeds_mask", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"), - OutputParam("negative_prompt_embeds_mask", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"), + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask"), OutputParam("prompt_embeds_2", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"), OutputParam("prompt_embeds_mask_2", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"), OutputParam("negative_prompt_embeds_2", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"), diff --git a/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py b/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py index 4cdeb68fa65e..866bafd9da6e 100644 --- a/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py +++ b/src/diffusers/modular_pipelines/hunyuan_video1_5/modular_blocks_hunyuan_video1_5.py @@ -40,28 +40,29 @@ class HunyuanVideo15CoreDenoiseStep(SequentialPipelineBlocks): guider (`ClassifierFreeGuidance`) Inputs: - num_videos_per_prompt (`None`, *optional*, defaults to 1): - TODO: 
Add description. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. prompt_embeds (`Tensor`): - TODO: Add description. + text embeddings used to guide the image generation. Can be generated from text_encoder step. batch_size (`int`, *optional*): - TODO: Add description. - num_inference_steps (`None`, *optional*, defaults to 50): - TODO: Add description. - sigmas (`None`, *optional*): - TODO: Add description. + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be + generated in input step. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. height (`int`, *optional*): - TODO: Add description. + The height in pixels of the generated image. width (`int`, *optional*): - TODO: Add description. + The width in pixels of the generated image. num_frames (`int`, *optional*, defaults to 121): TODO: Add description. - latents (`Tensor | NoneType`, *optional*): - TODO: Add description. - generator (`None`, *optional*): - TODO: Add description. - attention_kwargs (`None`, *optional*): - TODO: Add description. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. negative_prompt_embeds (`Tensor`, *optional*): TODO: Add description. prompt_embeds_mask (`Tensor`): @@ -117,18 +118,18 @@ class HunyuanVideo15Blocks(SequentialPipelineBlocks): video_processor (`HunyuanVideo15ImageProcessor`) Inputs: - prompt (`None`, *optional*): - TODO: Add description. - negative_prompt (`None`, *optional*): - TODO: Add description. + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. 
+ negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. prompt_embeds (`Tensor`, *optional*): - TODO: Add description. + text embeddings used to guide the image generation. Can be generated from text_encoder step. prompt_embeds_mask (`Tensor`, *optional*): - TODO: Add description. + mask for the text embeddings. Can be generated from text_encoder step. negative_prompt_embeds (`Tensor`, *optional*): - TODO: Add description. + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. negative_prompt_embeds_mask (`Tensor`, *optional*): - TODO: Add description. + mask for the negative text embeddings. Can be generated from text_encoder step. prompt_embeds_2 (`Tensor`, *optional*): TODO: Add description. prompt_embeds_mask_2 (`Tensor`, *optional*): @@ -138,27 +139,28 @@ class HunyuanVideo15Blocks(SequentialPipelineBlocks): negative_prompt_embeds_mask_2 (`Tensor`, *optional*): TODO: Add description. num_videos_per_prompt (`int`, *optional*, defaults to 1): - TODO: Add description. + The number of images to generate per prompt. batch_size (`int`, *optional*): - TODO: Add description. - num_inference_steps (`None`, *optional*, defaults to 50): - TODO: Add description. - sigmas (`None`, *optional*): - TODO: Add description. + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be + generated in input step. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. height (`int`, *optional*): - TODO: Add description. + The height in pixels of the generated image. width (`int`, *optional*): - TODO: Add description. + The width in pixels of the generated image. num_frames (`int`, *optional*, defaults to 121): TODO: Add description. - latents (`Tensor | NoneType`, *optional*): - TODO: Add description. 
- generator (`None`, *optional*): - TODO: Add description. - attention_kwargs (`None`, *optional*): - TODO: Add description. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. output_type (`str`, *optional*, defaults to np): - TODO: Add description. + Output format: 'pil', 'np', 'pt'. Outputs: videos (`list`): @@ -197,26 +199,27 @@ class HunyuanVideo15Image2VideoCoreDenoiseStep(SequentialPipelineBlocks): guider (`ClassifierFreeGuidance`) Inputs: - num_videos_per_prompt (`None`, *optional*, defaults to 1): - TODO: Add description. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. prompt_embeds (`Tensor`): - TODO: Add description. + text embeddings used to guide the image generation. Can be generated from text_encoder step. batch_size (`int`, *optional*): - TODO: Add description. - num_inference_steps (`None`, *optional*, defaults to 50): - TODO: Add description. - sigmas (`None`, *optional*): - TODO: Add description. - image (`None`): - TODO: Add description. + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be + generated in input step. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. num_frames (`int`, *optional*, defaults to 121): TODO: Add description. - latents (`Tensor | NoneType`, *optional*): - TODO: Add description. - generator (`None`, *optional*): - TODO: Add description. - attention_kwargs (`None`, *optional*): - TODO: Add description. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. 
+ generator (`Generator`, *optional*): + Torch generator for deterministic generation. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. negative_prompt_embeds (`Tensor`, *optional*): TODO: Add description. prompt_embeds_mask (`Tensor`): @@ -274,18 +277,18 @@ class HunyuanVideo15Image2VideoBlocks(SequentialPipelineBlocks): feature_extractor (`SiglipImageProcessor`) Inputs: - prompt (`None`, *optional*): - TODO: Add description. - negative_prompt (`None`, *optional*): - TODO: Add description. + prompt (`str`, *optional*): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. prompt_embeds (`Tensor`, *optional*): - TODO: Add description. + text embeddings used to guide the image generation. Can be generated from text_encoder step. prompt_embeds_mask (`Tensor`, *optional*): - TODO: Add description. + mask for the text embeddings. Can be generated from text_encoder step. negative_prompt_embeds (`Tensor`, *optional*): - TODO: Add description. + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. negative_prompt_embeds_mask (`Tensor`, *optional*): - TODO: Add description. + mask for the negative text embeddings. Can be generated from text_encoder step. prompt_embeds_2 (`Tensor`, *optional*): TODO: Add description. prompt_embeds_mask_2 (`Tensor`, *optional*): @@ -295,25 +298,26 @@ class HunyuanVideo15Image2VideoBlocks(SequentialPipelineBlocks): negative_prompt_embeds_mask_2 (`Tensor`, *optional*): TODO: Add description. num_videos_per_prompt (`int`, *optional*, defaults to 1): - TODO: Add description. + The number of images to generate per prompt. batch_size (`int`, *optional*): - TODO: Add description. - num_inference_steps (`None`, *optional*, defaults to 50): - TODO: Add description. - sigmas (`None`, *optional*): - TODO: Add description. - image (`None`): - TODO: Add description. 
+ Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be + generated in input step. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. num_frames (`int`, *optional*, defaults to 121): TODO: Add description. - latents (`Tensor | NoneType`, *optional*): - TODO: Add description. - generator (`None`, *optional*): - TODO: Add description. - attention_kwargs (`None`, *optional*): - TODO: Add description. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. output_type (`str`, *optional*, defaults to np): - TODO: Add description. + Output format: 'pil', 'np', 'pt'. Outputs: videos (`list`):