diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0f74c0bbcb4a..05ba35e5f924 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -452,6 +452,9 @@ "HeliosPyramidDistilledAutoBlocks", "HeliosPyramidDistilledModularPipeline", "HeliosPyramidModularPipeline", + "LTXBlocks", + "LTXImage2VideoBlocks", + "LTXModularPipeline", "QwenImageAutoBlocks", "QwenImageEditAutoBlocks", "QwenImageEditModularPipeline", @@ -1227,6 +1230,9 @@ HeliosPyramidDistilledAutoBlocks, HeliosPyramidDistilledModularPipeline, HeliosPyramidModularPipeline, + LTXBlocks, + LTXImage2VideoBlocks, + LTXModularPipeline, QwenImageAutoBlocks, QwenImageEditAutoBlocks, QwenImageEditModularPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index fd9bd691ca87..c76861df96d4 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -88,6 +88,11 @@ "QwenImageLayeredModularPipeline", "QwenImageLayeredAutoBlocks", ] + _import_structure["ltx"] = [ + "LTXBlocks", + "LTXImage2VideoBlocks", + "LTXModularPipeline", + ] _import_structure["z_image"] = [ "ZImageAutoBlocks", "ZImageModularPipeline", @@ -119,6 +124,7 @@ HeliosPyramidDistilledModularPipeline, HeliosPyramidModularPipeline, ) + from .ltx import LTXBlocks, LTXImage2VideoBlocks, LTXModularPipeline from .modular_pipeline import ( AutoPipelineBlocks, BlockState, diff --git a/src/diffusers/modular_pipelines/ltx/__init__.py b/src/diffusers/modular_pipelines/ltx/__init__.py new file mode 100644 index 000000000000..6be74e6b4112 --- /dev/null +++ b/src/diffusers/modular_pipelines/ltx/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_ltx"] = ["LTXBlocks", "LTXImage2VideoBlocks"] + _import_structure["modular_pipeline"] = ["LTXModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_ltx import LTXBlocks, LTXImage2VideoBlocks + from .modular_pipeline import LTXModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/ltx/before_denoise.py b/src/diffusers/modular_pipelines/ltx/before_denoise.py new file mode 100644 index 000000000000..47344b55ea0d --- /dev/null +++ b/src/diffusers/modular_pipelines/ltx/before_denoise.py @@ -0,0 +1,411 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import numpy as np +import torch + +from ...models import LTXVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import LTXModularPipeline + + +logger = logging.get_logger(__name__) + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed.") + if timesteps is not None: + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom sigmas." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents +def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape + # [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). 
+ # dim=0 is the batch size, dim=1 is the effective video sequence length, + # dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + +# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents +def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 +) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + +class LTXTextInputStep(ModularPipelineBlocks): + model_name = "ltx" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` and `num_videos_per_prompt`" + ) + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("transformer", LTXVideoTransformer3DModel), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("num_images_per_prompt", name="num_videos_per_prompt"), + InputParam.template("prompt_embeds", required=True), + InputParam.template("prompt_embeds_mask", name="prompt_attention_mask"), + InputParam.template("negative_prompt_embeds"), + InputParam.template("negative_prompt_embeds_mask", name="negative_prompt_attention_mask"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("batch_size", type_hint=int), + OutputParam("dtype", type_hint=torch.dtype), + ] + + @torch.no_grad() + def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + num_videos = block_state.num_videos_per_prompt + + # Repeat prompt_embeds for num_videos_per_prompt + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, num_videos, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view(block_state.batch_size * num_videos, seq_len, -1) + + if block_state.prompt_attention_mask is not None: + block_state.prompt_attention_mask = block_state.prompt_attention_mask.repeat(num_videos, 1) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, num_videos, 1) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * num_videos, seq_len, -1 + ) + + if block_state.negative_prompt_attention_mask is not None: + block_state.negative_prompt_attention_mask = block_state.negative_prompt_attention_mask.repeat( + num_videos, 1 + ) + + self.set_block_state(state, 
block_state) + return components, state + + +class LTXSetTimestepsStep(ModularPipelineBlocks): + model_name = "ltx" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("num_inference_steps"), + InputParam.template("timesteps"), + InputParam.template("sigmas"), + InputParam.template("height", default=512), + InputParam.template("width", default=704), + InputParam("num_frames", type_hint=int, default=161), + InputParam("frame_rate", type_hint=int, default=25), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor), + OutputParam("num_inference_steps", type_hint=int), + OutputParam("rope_interpolation_scale", type_hint=tuple), + ] + + @torch.no_grad() + def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + height = block_state.height + width = block_state.width + num_frames = block_state.num_frames + frame_rate = block_state.frame_rate + + latent_num_frames = (num_frames - 1) // components.vae_temporal_compression_ratio + 1 + latent_height = height // components.vae_spatial_compression_ratio + latent_width = width // components.vae_spatial_compression_ratio + video_sequence_length = latent_num_frames * latent_height * latent_width + + custom_timesteps = block_state.timesteps + sigmas = block_state.sigmas + + if custom_timesteps is not None: + # User provided custom timesteps, don't compute sigmas + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + device, + custom_timesteps, + ) + else: + if sigmas is None: + sigmas = np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps) + + mu = calculate_shift( + video_sequence_length, + components.scheduler.config.get("base_image_seq_len", 256), + components.scheduler.config.get("max_image_seq_len", 4096), + components.scheduler.config.get("base_shift", 0.5), + components.scheduler.config.get("max_shift", 1.15), + ) + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + + block_state.rope_interpolation_scale = ( + components.vae_temporal_compression_ratio / frame_rate, + components.vae_spatial_compression_ratio, + components.vae_spatial_compression_ratio, + ) + + self.set_block_state(state, block_state) + return components, state + + +class LTXPrepareLatentsStep(ModularPipelineBlocks): + model_name = "ltx" + + @property + def description(self) -> str: + return "Prepare latents step that prepares the latents for the text-to-video generation process" + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("height", default=512), + InputParam.template("width", default=704), + InputParam("num_frames", type_hint=int, default=161), + InputParam.template("latents"), + InputParam.template("num_images_per_prompt", name="num_videos_per_prompt"), + InputParam.template("generator"), + InputParam.template("batch_size", required=True), + InputParam.template("dtype"), + ] + + @property + def 
intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("latents", type_hint=torch.Tensor), + ] + + @torch.no_grad() + def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + batch_size = block_state.batch_size * block_state.num_videos_per_prompt + num_channels_latents = components.transformer.config.in_channels + + if block_state.latents is not None: + block_state.latents = block_state.latents.to(device=device, dtype=torch.float32) + else: + height = block_state.height // components.vae_spatial_compression_ratio + width = block_state.width // components.vae_spatial_compression_ratio + num_frames = (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + block_state.latents = randn_tensor( + shape, generator=block_state.generator, device=device, dtype=torch.float32 + ) + block_state.latents = _pack_latents( + block_state.latents, + components.transformer_spatial_patch_size, + components.transformer_temporal_patch_size, + ) + + self.set_block_state(state, block_state) + return components, state + + +class LTXImage2VideoPrepareLatentsStep(ModularPipelineBlocks): + model_name = "ltx" + + @property + def description(self) -> str: + return ( + "Prepare latents step for image-to-video: takes pre-encoded image latents and creates a conditioning mask" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam("image_latents", type_hint=torch.Tensor, required=True), + InputParam.template("height", default=512), + InputParam.template("width", default=704), + InputParam("num_frames", type_hint=int, default=161), + InputParam.template("latents"), + InputParam.template("num_images_per_prompt", name="num_videos_per_prompt"), + InputParam.template("generator"), + InputParam.template("batch_size", required=True), + InputParam.template("dtype"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("latents", type_hint=torch.Tensor), + OutputParam("conditioning_mask", type_hint=torch.Tensor), + ] + + @torch.no_grad() + def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + batch_size = block_state.batch_size * block_state.num_videos_per_prompt + + height = block_state.height // components.vae_spatial_compression_ratio + width = block_state.width // components.vae_spatial_compression_ratio + num_frames = (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1 + + mask_shape = (batch_size, 1, num_frames, height, width) + + if block_state.latents is not None: + conditioning_mask = block_state.latents.new_zeros(mask_shape) + conditioning_mask[:, :, 0] = 1.0 + conditioning_mask = _pack_latents( + conditioning_mask, + components.transformer_spatial_patch_size, + components.transformer_temporal_patch_size, + ).squeeze(-1) + block_state.latents = block_state.latents.to(device=device, dtype=torch.float32) + block_state.conditioning_mask = conditioning_mask + self.set_block_state(state, block_state) + return components, state + + init_latents = block_state.image_latents.to(device=device, dtype=torch.float32) + if init_latents.shape[0] < batch_size: + init_latents = init_latents.repeat_interleave(batch_size // init_latents.shape[0], dim=0) + init_latents = 
init_latents.repeat(1, 1, num_frames, 1, 1) + + actual_mask_shape = ( + init_latents.shape[0], + 1, + init_latents.shape[2], + init_latents.shape[3], + init_latents.shape[4], + ) + conditioning_mask = torch.zeros(actual_mask_shape, device=device, dtype=torch.float32) + conditioning_mask[:, :, 0] = 1.0 + + noise = randn_tensor(init_latents.shape, generator=block_state.generator, device=device, dtype=torch.float32) + latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask) + + conditioning_mask = _pack_latents( + conditioning_mask, + components.transformer_spatial_patch_size, + components.transformer_temporal_patch_size, + ).squeeze(-1) + latents = _pack_latents( + latents, + components.transformer_spatial_patch_size, + components.transformer_temporal_patch_size, + ) + + block_state.latents = latents + block_state.conditioning_mask = conditioning_mask + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/ltx/decoders.py b/src/diffusers/modular_pipelines/ltx/decoders.py new file mode 100644 index 000000000000..7524d6f7f67d --- /dev/null +++ b/src/diffusers/modular_pipelines/ltx/decoders.py @@ -0,0 +1,147 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKLLTXVideo +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) + + +# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents +def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 +) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, + # D is the effective feature dimensions) are unpacked and reshaped into a video tensor + # of shape [B, C, F, H, W]. This is the inverse operation of what happens in the + # `_pack_latents` method. 
+ batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + +# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents +def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 +) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + +class LTXVaeDecoderStep(ModularPipelineBlocks): + model_name = "ltx" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLLTXVideo), + ComponentSpec( + "video_processor", + VideoProcessor, + config=FrozenDict({"vae_scale_factor": 32}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into videos" + + @property + def inputs(self) -> list[tuple[str, Any]]: + return [ + InputParam.template("latents", required=True), + InputParam.template("output_type", default="np"), + InputParam.template("height", default=512), + InputParam.template("width", default=704), + InputParam("num_frames", type_hint=int, default=161), + InputParam("decode_timestep", default=0.0), + InputParam("decode_noise_scale", default=None), + InputParam.template("generator"), + InputParam.template("batch_size"), + InputParam.template("dtype", required=True), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam.template("videos")] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae = components.vae + + latents = block_state.latents + + height = block_state.height + width = block_state.width + num_frames = block_state.num_frames + + latent_num_frames = (num_frames - 1) // components.vae_temporal_compression_ratio + 1 + latent_height = height // components.vae_spatial_compression_ratio + latent_width = width // components.vae_spatial_compression_ratio + + latents = _unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + components.transformer_spatial_patch_size, + components.transformer_temporal_patch_size, + ) + latents = _denormalize_latents(latents, vae.latents_mean, vae.latents_std, vae.config.scaling_factor) + latents = latents.to(block_state.dtype) + + if not vae.config.timestep_conditioning: + timestep = None + else: + device = latents.device + batch_size = block_state.batch_size + decode_timestep = block_state.decode_timestep + decode_noise_scale = block_state.decode_noise_scale + + noise = randn_tensor(latents.shape, generator=block_state.generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, 
dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.to(vae.dtype) + video = vae.decode(latents, timestep, return_dict=False)[0] + block_state.videos = components.video_processor.postprocess_video(video, output_type=block_state.output_type) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/ltx/denoise.py b/src/diffusers/modular_pipelines/ltx/denoise.py new file mode 100644 index 000000000000..e8f72ec4a477 --- /dev/null +++ b/src/diffusers/modular_pipelines/ltx/denoise.py @@ -0,0 +1,513 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import LTXVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam +from .modular_pipeline import LTXModularPipeline + + +logger = logging.get_logger(__name__) + + +# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents +def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape + # [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, + # dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + +# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents +def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 +) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, + # D is the effective feature dimensions) are unpacked and reshaped into a video tensor + # of shape [B, C, F, H, W]. This is the inverse operation of what happens in the + # `_pack_latents` method. 
+ batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + +class LTXLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "ltx" + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `LTXDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("latents", required=True), + InputParam.template("dtype", required=True), + ] + + @torch.no_grad() + def __call__(self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.latent_model_input = block_state.latents.to(block_state.dtype) + return components, block_state + + +class LTXLoopDenoiser(ModularPipelineBlocks): + model_name = "ltx" + + def __init__( + self, + guider_input_fields: dict[str, Any] | None = None, + ): + if guider_input_fields is None: + guider_input_fields = { + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + "encoder_attention_mask": ("prompt_attention_mask", "negative_prompt_attention_mask"), + } + if not isinstance(guider_input_fields, dict): + raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 3.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", LTXVideoTransformer3DModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoises the latents with guidance. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. 
`LTXDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + inputs = [ + InputParam.template("attention_kwargs"), + InputParam.template("num_inference_steps", required=True), + InputParam("rope_interpolation_scale", type_hint=tuple), + InputParam.template("height"), + InputParam.template("width"), + InputParam("num_frames", type_hint=int), + ] + guider_input_names = [] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + guider_input_names.extend(value) + else: + guider_input_names.append(value) + + for name in guider_input_names: + inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor)) + return inputs + + @torch.no_grad() + def __call__( + self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + latent_num_frames = (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1 + latent_height = block_state.height // components.vae_spatial_compression_ratio + latent_width = block_state.width // components.vae_spatial_compression_ratio + + guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = { + k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v + for k, v in cond_kwargs.items() + if k in self._guider_input_fields.keys() + } + + context_name = getattr(guider_state_batch, components.guider._identifier_key, None) + with components.transformer.cache_context(context_name): + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latent_model_input, + timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype), + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + rope_interpolation_scale=block_state.rope_interpolation_scale, + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + components.guider.cleanup_models(components.transformer) + + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +class LTXLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "ltx" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that updates the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `LTXDenoiseLoopWrapper`)" + ) + + @torch.no_grad() + def __call__(self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class LTXDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "ltx" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoises the latents over `timesteps`. 
" + "The specific steps within each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", LTXVideoTransformer3DModel), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam.template("timesteps", required=True), + InputParam.template("num_inference_steps", required=True), + ] + + @torch.no_grad() + def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + return components, state + + +class LTXDenoiseStep(LTXDenoiseLoopWrapper): + block_classes = [ + LTXLoopBeforeDenoiser, + LTXLoopDenoiser( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + "encoder_attention_mask": ("prompt_attention_mask", "negative_prompt_attention_mask"), + } + ), + LTXLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents.\n" + "Its loop logic is defined in `LTXDenoiseLoopWrapper.__call__` method.\n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `LTXLoopBeforeDenoiser`\n" + " - `LTXLoopDenoiser`\n" + " - `LTXLoopAfterDenoiser`\n" + "This block supports text-to-video tasks." + ) + + +class LTXImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "ltx" + + @property + def description(self) -> str: + return ( + "Step within the i2v denoising loop that prepares the latent input and modulates " + "the timestep with the conditioning mask." 
+ ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("latents", required=True), + InputParam("conditioning_mask", required=True, type_hint=torch.Tensor), + InputParam.template("dtype", required=True), + ] + + @torch.no_grad() + def __call__(self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.latent_model_input = block_state.latents.to(block_state.dtype) + block_state.timestep_adjusted = t.expand(block_state.latent_model_input.shape[0]).unsqueeze(-1) * ( + 1 - block_state.conditioning_mask + ) + return components, block_state + + +class LTXImage2VideoLoopDenoiser(ModularPipelineBlocks): + model_name = "ltx" + + def __init__( + self, + guider_input_fields: dict[str, Any] | None = None, + ): + if guider_input_fields is None: + guider_input_fields = { + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + "encoder_attention_mask": ("prompt_attention_mask", "negative_prompt_attention_mask"), + } + if not isinstance(guider_input_fields, dict): + raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + + @property + def expected_components(self) -> list[ComponentSpec]: + from ...configuration_utils import FrozenDict + from ...guiders import ClassifierFreeGuidance + + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 3.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", LTXVideoTransformer3DModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the i2v denoising loop that denoises the latents with guidance " + "using timestep modulated by the conditioning mask." 
+ ) + + @property + def inputs(self) -> list[tuple[str, Any]]: + inputs = [ + InputParam.template("attention_kwargs"), + InputParam.template("num_inference_steps", required=True), + InputParam("rope_interpolation_scale", type_hint=tuple), + InputParam.template("height"), + InputParam.template("width"), + InputParam("num_frames", type_hint=int), + ] + guider_input_names = [] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + guider_input_names.extend(value) + else: + guider_input_names.append(value) + for name in guider_input_names: + inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor)) + return inputs + + @torch.no_grad() + def __call__( + self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + latent_num_frames = (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1 + latent_height = block_state.height // components.vae_spatial_compression_ratio + latent_width = block_state.width // components.vae_spatial_compression_ratio + + guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = { + k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v + for k, v in cond_kwargs.items() + if k in self._guider_input_fields.keys() + } + + context_name = getattr(guider_state_batch, components.guider._identifier_key, None) + with components.transformer.cache_context(context_name): + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latent_model_input, + timestep=block_state.timestep_adjusted, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + rope_interpolation_scale=block_state.rope_interpolation_scale, + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + components.guider.cleanup_models(components.transformer) + + block_state.noise_pred = components.guider(guider_state)[0] + + return components, block_state + + +class LTXImage2VideoLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "ltx" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] + + @property + def description(self) -> str: + return ( + "Step within the i2v denoising loop that updates the latents, " + "applying the scheduler step only to frames after the first (conditioned) frame." 
+ ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("height"), + InputParam.template("width"), + InputParam("num_frames", type_hint=int), + ] + + @torch.no_grad() + def __call__(self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latent_num_frames = (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1 + latent_height = block_state.height // components.vae_spatial_compression_ratio + latent_width = block_state.width // components.vae_spatial_compression_ratio + + noise_pred = _unpack_latents( + block_state.noise_pred, + latent_num_frames, + latent_height, + latent_width, + components.transformer_spatial_patch_size, + components.transformer_temporal_patch_size, + ) + latents = _unpack_latents( + block_state.latents, + latent_num_frames, + latent_height, + latent_width, + components.transformer_spatial_patch_size, + components.transformer_temporal_patch_size, + ) + + noise_pred = noise_pred[:, :, 1:] + noise_latents = latents[:, :, 1:] + pred_latents = components.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0] + + latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) + block_state.latents = _pack_latents( + latents, + components.transformer_spatial_patch_size, + components.transformer_temporal_patch_size, + ) + + return components, block_state + + +class LTXImage2VideoDenoiseStep(LTXDenoiseLoopWrapper): + block_classes = [ + LTXImage2VideoLoopBeforeDenoiser, + LTXImage2VideoLoopDenoiser( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + "encoder_attention_mask": ("prompt_attention_mask", "negative_prompt_attention_mask"), + } + ), + LTXImage2VideoLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step for image-to-video that iteratively denoises the latents.\n" + "The first frame is kept fixed via a conditioning mask.\n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `LTXImage2VideoLoopBeforeDenoiser`\n" + " - `LTXImage2VideoLoopDenoiser`\n" + " - `LTXImage2VideoLoopAfterDenoiser`" + ) diff --git a/src/diffusers/modular_pipelines/ltx/encoders.py b/src/diffusers/modular_pipelines/ltx/encoders.py new file mode 100644 index 000000000000..ec76a86cf2f1 --- /dev/null +++ b/src/diffusers/modular_pipelines/ltx/encoders.py @@ -0,0 +1,274 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import AutoencoderKLLTXVideo +from ...utils import logging +from ...video_processor import VideoProcessor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import LTXModularPipeline + + +logger = logging.get_logger(__name__) + + +def _get_t5_prompt_embeds( + components, + prompt: str | list[str], + max_sequence_length: int, + device: torch.device, + dtype: torch.dtype, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = components.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool().to(device) + + prompt_embeds = components.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + +class LTXTextEncoderStep(ModularPipelineBlocks): + model_name = "ltx" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings to guide the video generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", T5EncoderModel), + ComponentSpec("tokenizer", T5TokenizerFast), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 3.0}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("prompt"), + InputParam.template("negative_prompt"), + InputParam.template("max_sequence_length", default=128), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask", name="prompt_attention_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask", name="negative_prompt_attention_mask"), + ] + + @staticmethod + def check_inputs(block_state): + if block_state.prompt is not None and ( + not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list) + ): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + + @staticmethod + def encode_prompt( + components, + prompt: str, + device: torch.device | None = None, + prepare_unconditional_embeds: bool = True, + negative_prompt: str | None = None, + max_sequence_length: int = 128, + ): + device = device or components._execution_device + dtype = components.text_encoder.dtype + + if not isinstance(prompt, list): + prompt = [prompt] + batch_size = len(prompt) + + prompt_embeds, prompt_attention_mask = _get_t5_prompt_embeds( + components=components, + prompt=prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + if prepare_unconditional_embeds: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if batch_size != len(negative_prompt): + raise ValueError( + 
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = _get_t5_prompt_embeds( + components=components, + prompt=negative_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + @torch.no_grad() + def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.device = components._execution_device + + ( + block_state.prompt_embeds, + block_state.prompt_attention_mask, + block_state.negative_prompt_embeds, + block_state.negative_prompt_attention_mask, + ) = self.encode_prompt( + components=components, + prompt=block_state.prompt, + device=block_state.device, + prepare_unconditional_embeds=components.requires_unconditional_embeds, + negative_prompt=block_state.negative_prompt, + max_sequence_length=block_state.max_sequence_length, + ) + + self.set_block_state(state, block_state) + return components, state + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents +def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 +) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + +class LTXVaeEncoderStep(ModularPipelineBlocks): + model_name = "ltx" + + @property + def description(self) -> str: + return "VAE Encoder step that encodes an input image into latent space for image-to-video generation" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLLTXVideo), + ComponentSpec( + "video_processor", + VideoProcessor, + config=FrozenDict({"vae_scale_factor": 32}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("image", required=True), + InputParam.template("height", default=512), + InputParam.template("width", default=704), + InputParam.template("generator"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "image_latents", + type_hint=torch.Tensor, + description="Encoded image latents from the VAE encoder", + ), + ] + + @torch.no_grad() + def __call__(self, components: LTXModularPipeline, state: 
PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + device = components._execution_device + + image = block_state.image + if not isinstance(image, torch.Tensor): + image = components.video_processor.preprocess(image, height=block_state.height, width=block_state.width) + image = image.to(device=device, dtype=torch.float32) + + vae_dtype = components.vae.dtype + + num_images = image.shape[0] + if isinstance(block_state.generator, list): + init_latents = [ + retrieve_latents( + components.vae.encode(image[i].unsqueeze(0).unsqueeze(2).to(vae_dtype)), + block_state.generator[i], + ) + for i in range(num_images) + ] + else: + init_latents = [ + retrieve_latents( + components.vae.encode(img.unsqueeze(0).unsqueeze(2).to(vae_dtype)), + block_state.generator, + ) + for img in image + ] + + init_latents = torch.cat(init_latents, dim=0).to(torch.float32) + block_state.image_latents = _normalize_latents( + init_latents, components.vae.latents_mean, components.vae.latents_std + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/ltx/modular_blocks_ltx.py b/src/diffusers/modular_pipelines/ltx/modular_blocks_ltx.py new file mode 100644 index 000000000000..76c69e3f0fdb --- /dev/null +++ b/src/diffusers/modular_pipelines/ltx/modular_blocks_ltx.py @@ -0,0 +1,298 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import OutputParam +from .before_denoise import ( + LTXImage2VideoPrepareLatentsStep, + LTXPrepareLatentsStep, + LTXSetTimestepsStep, + LTXTextInputStep, +) +from .decoders import LTXVaeDecoderStep +from .denoise import LTXDenoiseStep, LTXImage2VideoDenoiseStep +from .encoders import LTXTextEncoderStep, LTXVaeEncoderStep + + +logger = logging.get_logger(__name__) + + +# auto_docstring +class LTXCoreDenoiseStep(SequentialPipelineBlocks): + """ + Denoise block that takes encoded conditions and runs the denoising process. + + Components: + transformer (`LTXVideoTransformer3DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider + (`ClassifierFreeGuidance`) + + Inputs: + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_attention_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_attention_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. 
+ timesteps (`Tensor`, *optional*): + Timesteps for the denoising process. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 704): + The width in pixels of the generated image. + num_frames (`int`, *optional*, defaults to 161): + TODO: Add description. + frame_rate (`int`, *optional*, defaults to 25): + TODO: Add description. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "ltx" + block_classes = [ + LTXTextInputStep, + LTXSetTimestepsStep, + LTXPrepareLatentsStep, + LTXDenoiseStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] + + @property + def description(self): + return "Denoise block that takes encoded conditions and runs the denoising process." + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# auto_docstring +class LTXImage2VideoCoreDenoiseStep(SequentialPipelineBlocks): + """ + Denoise block for image-to-video that takes encoded conditions and an image, and runs the denoising process. + + Components: + transformer (`LTXVideoTransformer3DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) vae + (`AutoencoderKLLTXVideo`) video_processor (`VideoProcessor`) guider (`ClassifierFreeGuidance`) + + Inputs: + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_attention_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_attention_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + timesteps (`Tensor`, *optional*): + Timesteps for the denoising process. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 704): + The width in pixels of the generated image. + num_frames (`int`, *optional*, defaults to 161): + TODO: Add description. + frame_rate (`int`, *optional*, defaults to 25): + TODO: Add description. + image (`Image | list`): + Reference image(s) for denoising. Can be a single image or list of images. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + + Outputs: + latents (`Tensor`): + Denoised latents. 
+ """ + + model_name = "ltx" + block_classes = [ + LTXTextInputStep, + LTXSetTimestepsStep, + LTXVaeEncoderStep, + LTXImage2VideoPrepareLatentsStep, + LTXImage2VideoDenoiseStep, + ] + block_names = ["input", "set_timesteps", "vae_encoder", "prepare_latents", "denoise"] + + @property + def description(self): + return "Denoise block for image-to-video that takes encoded conditions and an image, and runs the denoising process." + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +# auto_docstring +class LTXBlocks(SequentialPipelineBlocks): + """ + Modular pipeline blocks for LTX Video text-to-video. + + Components: + text_encoder (`T5EncoderModel`) tokenizer (`T5TokenizerFast`) guider (`ClassifierFreeGuidance`) transformer + (`LTXVideoTransformer3DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) vae (`AutoencoderKLLTXVideo`) + video_processor (`VideoProcessor`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 128): + Maximum sequence length for prompt encoding. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + timesteps (`Tensor`, *optional*): + Timesteps for the denoising process. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 704): + The width in pixels of the generated image. + num_frames (`int`, *optional*, defaults to 161): + TODO: Add description. + frame_rate (`int`, *optional*, defaults to 25): + TODO: Add description. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + attention_kwargs (`dict`, *optional*): + Additional kwargs for attention processors. + output_type (`str`, *optional*, defaults to np): + Output format: 'pil', 'np', 'pt'. + decode_timestep (`None`, *optional*, defaults to 0.0): + TODO: Add description. + decode_noise_scale (`None`, *optional*): + TODO: Add description. + + Outputs: + videos (`list`): + The generated videos. + """ + + model_name = "ltx" + block_classes = [ + LTXTextEncoderStep, + LTXCoreDenoiseStep, + LTXVaeDecoderStep, + ] + block_names = ["text_encoder", "denoise", "decode"] + + @property + def description(self): + return "Modular pipeline blocks for LTX Video text-to-video." + + @property + def outputs(self): + return [OutputParam.template("videos")] + + +# auto_docstring +class LTXImage2VideoBlocks(SequentialPipelineBlocks): + """ + Modular pipeline blocks for LTX Video image-to-video. + + Components: + text_encoder (`T5EncoderModel`) tokenizer (`T5TokenizerFast`) guider (`ClassifierFreeGuidance`) transformer + (`LTXVideoTransformer3DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) vae (`AutoencoderKLLTXVideo`) + video_processor (`VideoProcessor`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 128): + Maximum sequence length for prompt encoding. 
+        num_videos_per_prompt (`int`, *optional*, defaults to 1):
+            The number of videos to generate per prompt.
+        num_inference_steps (`int`, *optional*, defaults to 50):
+            The number of denoising steps.
+        timesteps (`Tensor`, *optional*):
+            Timesteps for the denoising process.
+        sigmas (`list`, *optional*):
+            Custom sigmas for the denoising process.
+        height (`int`, *optional*, defaults to 512):
+            The height in pixels of the generated video.
+        width (`int`, *optional*, defaults to 704):
+            The width in pixels of the generated video.
+        num_frames (`int`, *optional*, defaults to 161):
+            The number of video frames to generate.
+        frame_rate (`int`, *optional*, defaults to 25):
+            The frame rate of the generated video.
+        image (`Image | list`):
+            Reference image(s) for denoising. Can be a single image or list of images.
+        generator (`Generator`, *optional*):
+            Torch generator for deterministic generation.
+        latents (`Tensor`, *optional*):
+            Pre-generated noisy latents for video generation.
+        attention_kwargs (`dict`, *optional*):
+            Additional kwargs for attention processors.
+        output_type (`str`, *optional*, defaults to `"np"`):
+            Output format: 'pil', 'np', 'pt'.
+        decode_timestep (`float`, *optional*, defaults to 0.0):
+            The timestep at which the generated video is decoded.
+        decode_noise_scale (`float`, *optional*):
+            The interpolation factor between random noise and denoised latents at the decode timestep.
+
+    Outputs:
+        videos (`list`):
+            The generated videos.
+    """
+
+    model_name = "ltx"
+    block_classes = [
+        LTXTextEncoderStep,
+        LTXImage2VideoCoreDenoiseStep,
+        LTXVaeDecoderStep,
+    ]
+    block_names = ["text_encoder", "denoise", "decode"]
+
+    @property
+    def description(self):
+        return "Modular pipeline blocks for LTX Video image-to-video."
+
+    @property
+    def outputs(self):
+        return [OutputParam.template("videos")]
diff --git a/src/diffusers/modular_pipelines/ltx/modular_pipeline.py b/src/diffusers/modular_pipelines/ltx/modular_pipeline.py
new file mode 100644
index 000000000000..3cce6845396b
--- /dev/null
+++ b/src/diffusers/modular_pipelines/ltx/modular_pipeline.py
@@ -0,0 +1,64 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...loaders import LTXVideoLoraLoaderMixin
+from ...utils import logging
+from ..modular_pipeline import ModularPipeline
+
+
+logger = logging.get_logger(__name__)
+
+
+class LTXModularPipeline(
+    ModularPipeline,
+    LTXVideoLoraLoaderMixin,
+):
+    """
+    A ModularPipeline for LTX Video.
+
+    > [!WARNING] > This is an experimental feature and is likely to change in the future.
+ """ + + default_blocks_name = "LTXBlocks" + + @property + def vae_spatial_compression_ratio(self): + if getattr(self, "vae", None) is not None: + return self.vae.spatial_compression_ratio + return 32 + + @property + def vae_temporal_compression_ratio(self): + if getattr(self, "vae", None) is not None: + return self.vae.temporal_compression_ratio + return 8 + + @property + def transformer_spatial_patch_size(self): + if getattr(self, "transformer", None) is not None: + return self.transformer.config.patch_size + return 1 + + @property + def transformer_temporal_patch_size(self): + if getattr(self, "transformer", None) is not None: + return self.transformer.config.patch_size_t + return 1 + + @property + def requires_unconditional_embeds(self): + if hasattr(self, "guider") and self.guider is not None: + return self.guider._enabled and self.guider.num_conditions > 1 + return False diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 9cd2f9f5c6ae..ace89f0d6f91 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -132,6 +132,7 @@ def _helios_pyramid_map_fn(config_dict=None): ("z-image", _create_default_map_fn("ZImageModularPipeline")), ("helios", _create_default_map_fn("HeliosModularPipeline")), ("helios-pyramid", _helios_pyramid_map_fn), + ("ltx", _create_default_map_fn("LTXModularPipeline")), ] ) diff --git a/tests/modular_pipelines/ltx/__init__.py b/tests/modular_pipelines/ltx/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/modular_pipelines/ltx/test_modular_pipeline_ltx.py b/tests/modular_pipelines/ltx/test_modular_pipeline_ltx.py new file mode 100644 index 000000000000..00e68d26fdee --- /dev/null +++ b/tests/modular_pipelines/ltx/test_modular_pipeline_ltx.py @@ -0,0 +1,49 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from diffusers.modular_pipelines import LTXBlocks, LTXModularPipeline + +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +class TestLTXModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = LTXModularPipeline + pipeline_blocks_class = LTXBlocks + pretrained_model_name_or_path = "akshan-main/tiny-ltx-modular-pipe" + + params = frozenset(["prompt", "height", "width", "num_frames"]) + batch_params = frozenset(["prompt"]) + optional_params = frozenset(["num_inference_steps", "num_videos_per_prompt", "latents"]) + output_name = "videos" + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "height": 32, + "width": 32, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + @pytest.mark.skip(reason="num_videos_per_prompt") + def test_num_images_per_prompt(self): + pass