diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 7582a56505f7..885e2aa27181 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -488,6 +488,8 @@ title: AudioLDM 2 - local: api/pipelines/stable_audio title: Stable Audio + - local: api/pipelines/longcat_audio_dit + title: LongCat-AudioDiT title: Audio - sections: - local: api/pipelines/animatediff diff --git a/docs/source/en/api/pipelines/longcat_audio_dit.md b/docs/source/en/api/pipelines/longcat_audio_dit.md new file mode 100644 index 000000000000..86488416727e --- /dev/null +++ b/docs/source/en/api/pipelines/longcat_audio_dit.md @@ -0,0 +1,61 @@ + + +# LongCat-AudioDiT + +LongCat-AudioDiT is a text-to-audio diffusion model from Meituan LongCat. The diffusers integration exposes a standard [`DiffusionPipeline`] interface for text-conditioned audio generation. + +This pipeline supports loading the original flat LongCat checkpoint layout from either a local directory or a Hugging Face Hub repository containing: + +- `config.json` +- `model.safetensors` + +The loader builds the text encoder, transformer, and VAE from `config.json`, restores component weights from `model.safetensors`, and ties the shared UMT5 embedding when needed. 
+ +This pipeline was adapted from the LongCat-AudioDiT reference implementation: https://github.com/meituan-longcat/LongCat-AudioDiT + +## Usage + +```py +import soundfile as sf +import torch +from diffusers import LongCatAudioDiTPipeline + +pipeline = LongCatAudioDiTPipeline.from_pretrained( + "meituan-longcat/LongCat-AudioDiT-1B", + torch_dtype=torch.float16, +) +pipeline = pipeline.to("cuda") + +audio = pipeline( + prompt="A calm ocean wave ambience with soft wind in the background.", + audio_end_in_s=5.0, + num_inference_steps=16, + guidance_scale=4.0, + output_type="pt", +).audios + +output = audio[0, 0].float().cpu().numpy() +sf.write("longcat.wav", output, pipeline.sample_rate) +``` + +## Tips + +- `audio_end_in_s` is the most direct way to control output duration. +- `output_type="pt"` returns a PyTorch tensor shaped `(batch, channels, samples)`. + +## LongCatAudioDiTPipeline + +[[autodoc]] LongCatAudioDiTPipeline + - all + - __call__ + - from_pretrained diff --git a/docs/source/en/api/pipelines/overview.md b/docs/source/en/api/pipelines/overview.md index c3e493c63d6a..2d5c4ff74039 100644 --- a/docs/source/en/api/pipelines/overview.md +++ b/docs/source/en/api/pipelines/overview.md @@ -29,6 +29,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an |---|---| | [AnimateDiff](animatediff) | text2video | | [AudioLDM2](audioldm2) | text2audio | +| [LongCat-AudioDiT](longcat_audio_dit) | text2audio | | [AuraFlow](aura_flow) | text2image | | [Bria 3.2](bria_3_2) | text2image | | [CogVideoX](cogvideox) | text2video | diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0f74c0bbcb4a..b48a7f0a1c46 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -212,6 +212,7 @@ "AutoencoderKLTemporalDecoder", "AutoencoderKLWan", "AutoencoderOobleck", + "LongCatAudioDiTVae", "AutoencoderRAE", "AutoencoderTiny", "AutoencoderVidTok", @@ -253,6 +254,7 @@ "Kandinsky5Transformer3DModel", "LatteTransformer3DModel", 
"LongCatImageTransformer2DModel", + "LongCatAudioDiTTransformer", "LTX2VideoTransformer3DModel", "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", @@ -594,6 +596,7 @@ "LLaDA2PipelineOutput", "LongCatImageEditPipeline", "LongCatImagePipeline", + "LongCatAudioDiTPipeline", "LTX2ConditionPipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline", @@ -1007,6 +1010,7 @@ AutoencoderKLTemporalDecoder, AutoencoderKLWan, AutoencoderOobleck, + LongCatAudioDiTVae, AutoencoderRAE, AutoencoderTiny, AutoencoderVidTok, @@ -1048,6 +1052,7 @@ Kandinsky5Transformer3DModel, LatteTransformer3DModel, LongCatImageTransformer2DModel, + LongCatAudioDiTTransformer, LTX2VideoTransformer3DModel, LTXVideoTransformer3DModel, Lumina2Transformer2DModel, @@ -1365,6 +1370,7 @@ LLaDA2PipelineOutput, LongCatImageEditPipeline, LongCatImagePipeline, + LongCatAudioDiTPipeline, LTX2ConditionPipeline, LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 7ded56049833..2b24b53a7035 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -51,6 +51,7 @@ _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] + _import_structure["autoencoders.autoencoder_longcat_audio_dit"] = ["LongCatAudioDiTVae"] _import_structure["autoencoders.autoencoder_rae"] = ["AutoencoderRAE"] _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.autoencoder_vidtok"] = ["AutoencoderVidTok"] @@ -112,6 +113,7 @@ _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"] _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_longcat_image"] 
= ["LongCatImageTransformer2DModel"] + _import_structure["transformers.transformer_longcat_audio_dit"] = ["LongCatAudioDiTTransformer"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 609146ec340d..803b27285a42 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -20,6 +20,7 @@ from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_kl_wan import AutoencoderKLWan +from .autoencoder_longcat_audio_dit import LongCatAudioDiTVae from .autoencoder_oobleck import AutoencoderOobleck from .autoencoder_rae import AutoencoderRAE from .autoencoder_tiny import AutoencoderTiny from .autoencoder_vidtok import AutoencoderVidTok diff --git a/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py b/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py new file mode 100644 index 000000000000..9ab0a0d27470 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py @@ -0,0 +1,394 @@ +# Copyright 2026 MeiTuan LongCat-AudioDiT Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Adapted from the LongCat-AudioDiT reference implementation: +# https://github.com/meituan-longcat/LongCat-AudioDiT + +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import weight_norm + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ...utils.torch_utils import randn_tensor +from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin + + +def _wn_conv1d(in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, bias=True): + return weight_norm(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)) + + +def _wn_conv_transpose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +class Snake1d(nn.Module): + def __init__(self, channels: int, alpha_logscale: bool = True): + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter(torch.zeros(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + alpha = self.alpha[None, :, None] + beta = self.beta[None, :, None] + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + return hidden_states + (1.0 / (beta + 1e-9)) * torch.sin(hidden_states * alpha).pow(2) + + +def _get_vae_activation(name: str, channels: int = 0) -> nn.Module: + if name == "elu": + return nn.ELU() + if name == "snake": + return Snake1d(channels) + raise ValueError(f"Unknown activation: {name}") + + +def _pixel_unshuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor: + batch, channels, width = hidden_states.size() + return ( + hidden_states.view(batch, channels, width // factor, factor) + .permute(0, 1, 3, 2) + .contiguous() + .view(batch, channels * factor, width // factor) + ) + + +def _pixel_shuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor: + 
batch, channels, width = hidden_states.size() + return ( + hidden_states.view(batch, channels // factor, factor, width) + .permute(0, 1, 3, 2) + .contiguous() + .view(batch, channels // factor, width * factor) + ) + + +class DownsampleShortcut(nn.Module): + def __init__(self, in_channels: int, out_channels: int, factor: int): + super().__init__() + self.factor = factor + self.group_size = in_channels * factor // out_channels + self.out_channels = out_channels + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = _pixel_unshuffle_1d(hidden_states, self.factor) + batch, _channels, width = hidden_states.shape + return hidden_states.view(batch, self.out_channels, self.group_size, width).mean(dim=2) + + +class UpsampleShortcut(nn.Module): + def __init__(self, in_channels: int, out_channels: int, factor: int): + super().__init__() + self.factor = factor + self.repeats = out_channels * factor // in_channels + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states.repeat_interleave(self.repeats, dim=1) + return _pixel_shuffle_1d(hidden_states, self.factor) + + +class VaeResidualUnit(nn.Module): + def __init__( + self, in_channels: int, out_channels: int, dilation: int, kernel_size: int = 7, use_snake: bool = False + ): + super().__init__() + padding = (dilation * (kernel_size - 1)) // 2 + activation = "snake" if use_snake else "elu" + self.layers = nn.Sequential( + _get_vae_activation(activation, channels=out_channels), + _wn_conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=padding), + _get_vae_activation(activation, channels=out_channels), + _wn_conv1d(out_channels, out_channels, kernel_size=1), + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return hidden_states + self.layers(hidden_states) + + +class VaeEncoderBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int, + use_snake: bool = False, + 
downsample_shortcut: str = "none", + ): + super().__init__() + layers = [ + VaeResidualUnit(in_channels, in_channels, dilation=1, use_snake=use_snake), + VaeResidualUnit(in_channels, in_channels, dilation=3, use_snake=use_snake), + VaeResidualUnit(in_channels, in_channels, dilation=9, use_snake=use_snake), + ] + activation = "snake" if use_snake else "elu" + layers.append(_get_vae_activation(activation, channels=in_channels)) + layers.append( + _wn_conv1d(in_channels, out_channels, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)) + ) + self.layers = nn.Sequential(*layers) + self.residual = ( + DownsampleShortcut(in_channels, out_channels, stride) if downsample_shortcut == "averaging" else None + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.residual is None: + return self.layers(hidden_states) + return self.layers(hidden_states) + self.residual(hidden_states) + + +class VaeDecoderBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int, + use_snake: bool = False, + upsample_shortcut: str = "none", + ): + super().__init__() + activation = "snake" if use_snake else "elu" + layers = [ + _get_vae_activation(activation, channels=in_channels), + _wn_conv_transpose1d( + in_channels, out_channels, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2) + ), + VaeResidualUnit(out_channels, out_channels, dilation=1, use_snake=use_snake), + VaeResidualUnit(out_channels, out_channels, dilation=3, use_snake=use_snake), + VaeResidualUnit(out_channels, out_channels, dilation=9, use_snake=use_snake), + ] + self.layers = nn.Sequential(*layers) + self.residual = ( + UpsampleShortcut(in_channels, out_channels, stride) if upsample_shortcut == "duplicating" else None + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.residual is None: + return self.layers(hidden_states) + return self.layers(hidden_states) + self.residual(hidden_states) + + +class 
AudioDiTVaeEncoder(nn.Module): + def __init__( + self, + in_channels: int = 1, + channels: int = 128, + c_mults=None, + strides=None, + latent_dim: int = 64, + encoder_latent_dim: int = 128, + use_snake: bool = True, + downsample_shortcut: str = "averaging", + out_shortcut: str = "averaging", + ): + super().__init__() + c_mults = [1] + (c_mults or [1, 2, 4, 8, 16]) + strides = list(strides or [2] * (len(c_mults) - 1)) + if len(strides) < len(c_mults) - 1: + strides.extend([strides[-1] if strides else 2] * (len(c_mults) - 1 - len(strides))) + else: + strides = strides[: len(c_mults) - 1] + channels_base = channels + layers = [_wn_conv1d(in_channels, c_mults[0] * channels_base, kernel_size=7, padding=3)] + for idx in range(len(c_mults) - 1): + layers.append( + VaeEncoderBlock( + c_mults[idx] * channels_base, + c_mults[idx + 1] * channels_base, + strides[idx], + use_snake=use_snake, + downsample_shortcut=downsample_shortcut, + ) + ) + layers.append(_wn_conv1d(c_mults[-1] * channels_base, encoder_latent_dim, kernel_size=3, padding=1)) + self.layers = nn.Sequential(*layers) + self.shortcut = ( + DownsampleShortcut(c_mults[-1] * channels_base, encoder_latent_dim, 1) + if out_shortcut == "averaging" + else None + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.shortcut is None: + return self.layers(hidden_states) + hidden_states = self.layers[:-1](hidden_states) + return self.layers[-1](hidden_states) + self.shortcut(hidden_states) + + +class AudioDiTVaeDecoder(nn.Module): + def __init__( + self, + in_channels: int = 1, + channels: int = 128, + c_mults=None, + strides=None, + latent_dim: int = 64, + use_snake: bool = True, + in_shortcut: str = "duplicating", + final_tanh: bool = False, + upsample_shortcut: str = "duplicating", + ): + super().__init__() + c_mults = [1] + (c_mults or [1, 2, 4, 8, 16]) + strides = list(strides or [2] * (len(c_mults) - 1)) + if len(strides) < len(c_mults) - 1: + strides.extend([strides[-1] if strides else 2] * 
(len(c_mults) - 1 - len(strides))) + else: + strides = strides[: len(c_mults) - 1] + channels_base = channels + + self.shortcut = ( + UpsampleShortcut(latent_dim, c_mults[-1] * channels_base, 1) if in_shortcut == "duplicating" else None + ) + + layers = [_wn_conv1d(latent_dim, c_mults[-1] * channels_base, kernel_size=7, padding=3)] + for idx in range(len(c_mults) - 1, 0, -1): + layers.append( + VaeDecoderBlock( + c_mults[idx] * channels_base, + c_mults[idx - 1] * channels_base, + strides[idx - 1], + use_snake=use_snake, + upsample_shortcut=upsample_shortcut, + ) + ) + activation = "snake" if use_snake else "elu" + layers.append(_get_vae_activation(activation, channels=c_mults[0] * channels_base)) + layers.append(_wn_conv1d(c_mults[0] * channels_base, in_channels, kernel_size=7, padding=3, bias=False)) + layers.append(nn.Tanh() if final_tanh else nn.Identity()) + self.layers = nn.Sequential(*layers) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.shortcut is None: + return self.layers(hidden_states) + hidden_states = self.shortcut(hidden_states) + self.layers[0](hidden_states) + return self.layers[1:](hidden_states) + + +@dataclass +class LongCatAudioDiTVaeEncoderOutput(BaseOutput): + latents: torch.Tensor + + +@dataclass +class LongCatAudioDiTVaeDecoderOutput(BaseOutput): + sample: torch.Tensor + + +class LongCatAudioDiTVae(ModelMixin, AutoencoderMixin, ConfigMixin): + @register_to_config + def __init__( + self, + in_channels: int = 1, + channels: int = 128, + c_mults=None, + strides=None, + latent_dim: int = 64, + encoder_latent_dim: int = 128, + use_snake: bool = True, + downsample_shortcut: str = "averaging", + upsample_shortcut: str = "duplicating", + out_shortcut: str = "averaging", + in_shortcut: str = "duplicating", + final_tanh: bool = False, + downsampling_ratio: int = 2048, + sample_rate: int = 24000, + scale: float = 0.71, + ): + super().__init__() + self.encoder = AudioDiTVaeEncoder( + in_channels=in_channels, + 
channels=channels, + c_mults=c_mults, + strides=strides, + latent_dim=latent_dim, + encoder_latent_dim=encoder_latent_dim, + use_snake=use_snake, + downsample_shortcut=downsample_shortcut, + out_shortcut=out_shortcut, + ) + self.decoder = AudioDiTVaeDecoder( + in_channels=in_channels, + channels=channels, + c_mults=c_mults, + strides=strides, + latent_dim=latent_dim, + use_snake=use_snake, + in_shortcut=in_shortcut, + final_tanh=final_tanh, + upsample_shortcut=upsample_shortcut, + ) + + @property + def sampling_rate(self) -> int: + return self.config.sample_rate + + def encode( + self, + sample: torch.Tensor, + sample_posterior: bool = True, + return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> LongCatAudioDiTVaeEncoderOutput | tuple[torch.Tensor]: + encoder_dtype = next(self.encoder.parameters()).dtype + if sample.dtype != encoder_dtype: + sample = sample.to(encoder_dtype) + encoded = self.encoder(sample) + mean, scale_param = encoded.chunk(2, dim=1) + std = F.softplus(scale_param) + 1e-4 + if sample_posterior: + noise = randn_tensor(mean.shape, generator=generator, device=mean.device, dtype=mean.dtype) + latents = mean + std * noise + else: + latents = mean + latents = latents / self.config.scale + if encoder_dtype != torch.float32: + latents = latents.float() + if not return_dict: + return (latents,) + return LongCatAudioDiTVaeEncoderOutput(latents=latents) + + def decode( + self, latents: torch.Tensor, return_dict: bool = True + ) -> LongCatAudioDiTVaeDecoderOutput | tuple[torch.Tensor]: + decoder_dtype = next(self.decoder.parameters()).dtype + latents = latents * self.config.scale + if latents.dtype != decoder_dtype: + latents = latents.to(decoder_dtype) + decoded = self.decoder(latents) + if decoder_dtype != torch.float32: + decoded = decoded.float() + if not return_dict: + return (decoded,) + return LongCatAudioDiTVaeDecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + 
return_dict: bool = True, + generator: torch.Generator | None = None, + ) -> LongCatAudioDiTVaeDecoderOutput | tuple[torch.Tensor]: + latents = self.encode(sample, sample_posterior=sample_posterior, return_dict=True, generator=generator).latents + decoded = self.decode(latents, return_dict=True).sample + if not return_dict: + return (decoded,) + return LongCatAudioDiTVaeDecoderOutput(sample=decoded) diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 45157ee91808..ae91c5a54e49 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -38,6 +38,7 @@ + from .transformer_longcat_audio_dit import LongCatAudioDiTTransformer from .transformer_longcat_image import LongCatImageTransformer2DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_ltx2 import LTX2VideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel diff --git a/src/diffusers/models/transformers/transformer_longcat_audio_dit.py b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py new file mode 100644 index 000000000000..f9e9388aff55 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_longcat_audio_dit.py @@ -0,0 +1,639 @@ +# Copyright 2026 MeiTuan LongCat-AudioDiT Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from the LongCat-AudioDiT reference implementation: +# https://github.com/meituan-longcat/LongCat-AudioDiT + +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import AttentionModuleMixin +from ..attention_dispatch import dispatch_attention_fn +from ..modeling_utils import ModelMixin + + +@dataclass +class LongCatAudioDiTTransformerOutput(BaseOutput): + sample: torch.Tensor + + +class AudioDiTRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + normalized = hidden_states.float() * torch.rsqrt( + hidden_states.float().pow(2).mean(dim=-1, keepdim=True) + self.eps + ) + return normalized.to(hidden_states.dtype) * self.weight + + +class AudioDiTSinusPositionEmbedding(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, timesteps: torch.Tensor, scale: float = 1000.0) -> torch.Tensor: + device = timesteps.device + half_dim = self.dim // 2 + exponent = math.log(10000) / max(half_dim - 1, 1) + embeddings = torch.exp(torch.arange(half_dim, device=device).float() * -exponent) + embeddings = scale * timesteps.unsqueeze(1) * embeddings.unsqueeze(0) + return torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) + + +class AudioDiTTimestepEmbedding(nn.Module): + def __init__(self, dim: int, freq_embed_dim: int = 256): + super().__init__() + self.time_embed = AudioDiTSinusPositionEmbedding(freq_embed_dim) + self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + + 
def forward(self, timestep: torch.Tensor) -> torch.Tensor: + hidden_states = self.time_embed(timestep) + return self.time_mlp(hidden_states.to(timestep.dtype)) + + +class AudioDiTRotaryEmbedding(nn.Module): + def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 100000.0): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self._cos = None + self._sin = None + self._cached_len = 0 + self._cached_device = None + + def _build(self, seq_len: int, device: torch.device, dtype: torch.dtype): + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + steps = torch.arange(seq_len, dtype=torch.int64).type_as(inv_freq) + freqs = torch.outer(steps, inv_freq) + embeddings = torch.cat((freqs, freqs), dim=-1) + self._cos = embeddings.cos().to(dtype=torch.float32, device=device)  # keep the cache in fp32; forward() casts per call + self._sin = embeddings.sin().to(dtype=torch.float32, device=device) + self._cached_len = seq_len + self._cached_device = device + + def forward(self, hidden_states: torch.Tensor, seq_len: int | None = None) -> tuple[torch.Tensor, torch.Tensor]: + seq_len = hidden_states.shape[1] if seq_len is None else seq_len + if self._cos is None or seq_len > self._cached_len or self._cached_device != hidden_states.device: + self._build(max(seq_len, self.max_position_embeddings), hidden_states.device, hidden_states.dtype) + return self._cos[:seq_len].to(hidden_states.dtype), self._sin[:seq_len].to(hidden_states.dtype) + + +def _rotate_half(hidden_states: torch.Tensor) -> torch.Tensor: + first, second = hidden_states.chunk(2, dim=-1) + return torch.cat((-second, first), dim=-1) + + +def _apply_rotary_emb(hidden_states: torch.Tensor, rope: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + cos, sin = rope + cos = cos[None, :, None].to(hidden_states.device) + sin = sin[None, :, None].to(hidden_states.device) + return (hidden_states.float() * cos + _rotate_half(hidden_states).float() * 
sin).to(hidden_states.dtype) + + +class AudioDiTGRN(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, dim)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + gx = torch.norm(hidden_states, p=2, dim=1, keepdim=True) + nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (hidden_states * nx) + self.beta + hidden_states + + +class AudioDiTConvNeXtV2Block(nn.Module): + def __init__( + self, + dim: int, + intermediate_dim: int, + dilation: int = 1, + kernel_size: int = 7, + bias: bool = True, + eps: float = 1e-6, + ): + super().__init__() + padding = (dilation * (kernel_size - 1)) // 2 + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=kernel_size, padding=padding, groups=dim, dilation=dilation, bias=bias + ) + self.norm = nn.LayerNorm(dim, eps=eps) + self.pwconv1 = nn.Linear(dim, intermediate_dim, bias=bias) + self.act = nn.SiLU() + self.grn = AudioDiTGRN(intermediate_dim) + self.pwconv2 = nn.Linear(intermediate_dim, dim, bias=bias) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + hidden_states = self.dwconv(hidden_states.transpose(1, 2)).transpose(1, 2) + hidden_states = self.norm(hidden_states) + hidden_states = self.pwconv1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.grn(hidden_states) + hidden_states = self.pwconv2(hidden_states) + return residual + hidden_states + + +class AudioDiTEmbedder(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + self.proj = nn.Sequential(nn.Linear(in_dim, out_dim), nn.SiLU(), nn.Linear(out_dim, out_dim)) + + def forward(self, hidden_states: torch.Tensor, mask: torch.BoolTensor | None = None) -> torch.Tensor: + if mask is not None: + hidden_states = hidden_states.masked_fill(mask.logical_not().unsqueeze(-1), 0.0) + hidden_states = self.proj(hidden_states) + if mask is not 
None: + hidden_states = hidden_states.masked_fill(mask.logical_not().unsqueeze(-1), 0.0) + return hidden_states + + +class AudioDiTAdaLNMLP(nn.Module): + def __init__(self, in_dim: int, out_dim: int, bias: bool = True): + super().__init__() + self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(in_dim, out_dim, bias=bias)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.mlp(hidden_states) + + +class AudioDiTAdaLayerNormZeroFinal(nn.Module): + def __init__(self, dim: int, bias: bool = True, eps: float = 1e-6): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(dim, dim * 2, bias=bias) + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) + + def forward(self, hidden_states: torch.Tensor, embedding: torch.Tensor) -> torch.Tensor: + embedding = self.linear(self.silu(embedding)) + scale, shift = torch.chunk(embedding, 2, dim=-1) + hidden_states = self.norm(hidden_states.float()).type_as(hidden_states) + if scale.ndim == 2: + hidden_states = hidden_states * (1 + scale)[:, None, :] + shift[:, None, :] + else: + hidden_states = hidden_states * (1 + scale) + shift + return hidden_states + + +def _modulate( + hidden_states: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor, eps: float = 1e-6 +) -> torch.Tensor: + hidden_states = F.layer_norm(hidden_states.float(), (hidden_states.shape[-1],), eps=eps).type_as(hidden_states) + if scale.ndim == 2: + return hidden_states * (1 + scale[:, None]) + shift[:, None] + return hidden_states * (1 + scale) + shift + + +class AudioDiTSelfAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __call__( + self, + attn: "AudioDiTSelfAttention", + hidden_states: torch.Tensor, + mask: torch.BoolTensor | None = None, + rope: tuple | None = None, + ) -> torch.Tensor: + batch_size = hidden_states.shape[0] + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + if attn.qk_norm: + query = attn.q_norm(query) + key 
= attn.k_norm(key) + + head_dim = attn.inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + if rope is not None: + query = _apply_rotary_emb(query, rope) + key = _apply_rotary_emb(key, rope) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + if mask is not None: + hidden_states = hidden_states * mask[:, :, None, None].to(hidden_states.dtype) + + hidden_states = hidden_states.flatten(2, 3).to(query.dtype) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class AudioDiTSelfAttention(nn.Module, AttentionModuleMixin): + _default_processor_cls = AudioDiTSelfAttnProcessor + _available_processors = [AudioDiTSelfAttnProcessor] + def __init__( + self, + dim: int, + heads: int, + dim_head: int, + dropout: float = 0.0, + bias: bool = True, + qk_norm: bool = False, + eps: float = 1e-6, + ): + super().__init__() + self.heads = heads + self.inner_dim = dim_head * heads + self.to_q = nn.Linear(dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(dim, self.inner_dim, bias=bias) + self.qk_norm = qk_norm + if qk_norm: + self.q_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps) + self.k_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps) + self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, dim, bias=bias), nn.Dropout(dropout)]) + self.set_processor(self._default_processor_cls()) + + def forward( + self, hidden_states: torch.Tensor, mask: torch.BoolTensor | None = None, rope: tuple | None = None + ) -> torch.Tensor: + return self.processor(self, hidden_states, mask=mask, rope=rope) + + +class AudioDiTCrossAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __call__( + 
self, + attn: "AudioDiTCrossAttention", + hidden_states: torch.Tensor, + cond: torch.Tensor, + mask: torch.BoolTensor | None = None, + cond_mask: torch.BoolTensor | None = None, + rope: tuple | None = None, + cond_rope: tuple | None = None, + ) -> torch.Tensor: + batch_size = hidden_states.shape[0] + query = attn.to_q(hidden_states) + key = attn.to_k(cond) + value = attn.to_v(cond) + + if attn.qk_norm: + query = attn.q_norm(query) + key = attn.k_norm(key) + + head_dim = attn.inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, attn.heads, head_dim) + value = value.view(batch_size, -1, attn.heads, head_dim) + + if rope is not None: + query = _apply_rotary_emb(query, rope) + if cond_rope is not None: + key = _apply_rotary_emb(key, cond_rope) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=cond_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + if mask is not None: + hidden_states = hidden_states * mask[:, :, None, None].to(hidden_states.dtype) + + hidden_states = hidden_states.flatten(2, 3).to(query.dtype) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class AudioDiTCrossAttention(nn.Module, AttentionModuleMixin): + _default_processor_cls = AudioDiTCrossAttnProcessor + _available_processors = [AudioDiTCrossAttnProcessor] + def __init__( + self, + q_dim: int, + kv_dim: int, + heads: int, + dim_head: int, + dropout: float = 0.0, + bias: bool = True, + qk_norm: bool = False, + eps: float = 1e-6, + ): + super().__init__() + self.heads = heads + self.inner_dim = dim_head * heads + self.to_q = nn.Linear(q_dim, self.inner_dim, bias=bias) + self.to_k = nn.Linear(kv_dim, self.inner_dim, bias=bias) + self.to_v = nn.Linear(kv_dim, self.inner_dim, bias=bias) + self.qk_norm = qk_norm + if qk_norm: + self.q_norm = AudioDiTRMSNorm(self.inner_dim, eps=eps) + self.k_norm = 
AudioDiTRMSNorm(self.inner_dim, eps=eps) + self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, q_dim, bias=bias), nn.Dropout(dropout)]) + self.set_processor(self._default_processor_cls()) + + def forward( + self, + hidden_states: torch.Tensor, + cond: torch.Tensor, + mask: torch.BoolTensor | None = None, + cond_mask: torch.BoolTensor | None = None, + rope: tuple | None = None, + cond_rope: tuple | None = None, + ) -> torch.Tensor: + return self.processor( + self, + hidden_states, + cond=cond, + mask=mask, + cond_mask=cond_mask, + rope=rope, + cond_rope=cond_rope, + ) + + +class AudioDiTFeedForward(nn.Module): + def __init__(self, dim: int, mult: float = 4.0, dropout: float = 0.0, bias: bool = True): + super().__init__() + inner_dim = int(dim * mult) + self.ff = nn.Sequential( + nn.Linear(dim, inner_dim, bias=bias), + nn.GELU(approximate="tanh"), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim, bias=bias), + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.ff(hidden_states) + + +@maybe_allow_in_graph +class AudioDiTBlock(nn.Module): + def __init__( + self, + dim: int, + cond_dim: int, + heads: int, + dim_head: int, + dropout: float = 0.0, + bias: bool = True, + qk_norm: bool = False, + eps: float = 1e-6, + cross_attn: bool = True, + cross_attn_norm: bool = False, + adaln_type: str = "global", + adaln_use_text_cond: bool = True, + ff_mult: float = 4.0, + ): + super().__init__() + self.adaln_type = adaln_type + self.adaln_use_text_cond = adaln_use_text_cond + if adaln_type == "local": + self.adaln_mlp = AudioDiTAdaLNMLP(dim, dim * 6, bias=True) + elif adaln_type == "global": + self.adaln_scale_shift = nn.Parameter(torch.randn(dim * 6) / dim**0.5) + self.self_attn = AudioDiTSelfAttention( + dim, heads, dim_head, dropout=dropout, bias=bias, qk_norm=qk_norm, eps=eps + ) + self.use_cross_attn = cross_attn + if cross_attn: + self.cross_attn = AudioDiTCrossAttention( + dim, cond_dim, heads, dim_head, dropout=dropout, bias=bias, 
qk_norm=qk_norm, eps=eps + ) + self.cross_attn_norm = ( + nn.LayerNorm(dim, elementwise_affine=True, eps=eps) if cross_attn_norm else nn.Identity() + ) + self.cross_attn_norm_c = ( + nn.LayerNorm(cond_dim, elementwise_affine=True, eps=eps) if cross_attn_norm else nn.Identity() + ) + self.ffn = AudioDiTFeedForward(dim=dim, mult=ff_mult, dropout=dropout, bias=bias) + + def forward( + self, + hidden_states: torch.Tensor, + timestep_embed: torch.Tensor, + cond: torch.Tensor, + mask: torch.BoolTensor | None = None, + cond_mask: torch.BoolTensor | None = None, + rope: tuple | None = None, + cond_rope: tuple | None = None, + adaln_global_out: torch.Tensor | None = None, + ) -> torch.Tensor: + if self.adaln_type == "local" and adaln_global_out is None: + if self.adaln_use_text_cond: + denom = cond_mask.sum(1, keepdim=True).clamp(min=1).to(cond.dtype) + cond_mean = cond.sum(1) / denom + norm_cond = timestep_embed + cond_mean + else: + norm_cond = timestep_embed + adaln_out = self.adaln_mlp(norm_cond) + gate_sa, scale_sa, shift_sa, gate_ffn, scale_ffn, shift_ffn = torch.chunk(adaln_out, 6, dim=-1) + else: + adaln_out = adaln_global_out + self.adaln_scale_shift.unsqueeze(0) + gate_sa, scale_sa, shift_sa, gate_ffn, scale_ffn, shift_ffn = torch.chunk(adaln_out, 6, dim=-1) + + norm_hidden_states = _modulate(hidden_states, scale_sa, shift_sa) + attn_output = self.self_attn(norm_hidden_states, mask=mask, rope=rope) + hidden_states = hidden_states + gate_sa.unsqueeze(1) * attn_output + + if self.use_cross_attn: + cross_output = self.cross_attn( + hidden_states=self.cross_attn_norm(hidden_states), + cond=self.cross_attn_norm_c(cond), + mask=mask, + cond_mask=cond_mask, + rope=rope, + cond_rope=cond_rope, + ) + hidden_states = hidden_states + cross_output + + norm_hidden_states = _modulate(hidden_states, scale_ffn, shift_ffn) + ff_output = self.ffn(norm_hidden_states) + hidden_states = hidden_states + gate_ffn.unsqueeze(1) * ff_output + return hidden_states + + +class 
LongCatAudioDiTTransformer(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + dit_dim: int = 1536, + dit_depth: int = 24, + dit_heads: int = 24, + dit_text_dim: int = 768, + latent_dim: int = 64, + dropout: float = 0.0, + bias: bool = True, + cross_attn: bool = True, + adaln_type: str = "global", + adaln_use_text_cond: bool = True, + long_skip: bool = True, + text_conv: bool = True, + qk_norm: bool = True, + cross_attn_norm: bool = False, + eps: float = 1e-6, + use_latent_condition: bool = True, + ): + super().__init__() + dim = dit_dim + dim_head = dim // dit_heads + self.long_skip = long_skip + self.adaln_type = adaln_type + self.adaln_use_text_cond = adaln_use_text_cond + self.time_embed = AudioDiTTimestepEmbedding(dim) + self.input_embed = AudioDiTEmbedder(latent_dim, dim) + self.text_embed = AudioDiTEmbedder(dit_text_dim, dim) + self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0) + self.blocks = nn.ModuleList( + [ + AudioDiTBlock( + dim=dim, + cond_dim=dim, + heads=dit_heads, + dim_head=dim_head, + dropout=dropout, + bias=bias, + qk_norm=qk_norm, + eps=eps, + cross_attn=cross_attn, + cross_attn_norm=cross_attn_norm, + adaln_type=adaln_type, + adaln_use_text_cond=adaln_use_text_cond, + ff_mult=4.0, + ) + for _ in range(dit_depth) + ] + ) + self.norm_out = AudioDiTAdaLayerNormZeroFinal(dim, bias=bias, eps=eps) + self.proj_out = nn.Linear(dim, latent_dim) + if adaln_type == "global": + self.adaln_global_mlp = AudioDiTAdaLNMLP(dim, dim * 6, bias=True) + self.text_conv = text_conv + if text_conv: + self.text_conv_layer = nn.Sequential( + *[AudioDiTConvNeXtV2Block(dim, dim * 2, bias=bias, eps=eps) for _ in range(4)] + ) + self.use_latent_condition = use_latent_condition + if use_latent_condition: + self.latent_embed = AudioDiTEmbedder(latent_dim, dim) + self.latent_cond_embedder = AudioDiTEmbedder(dim * 2, dim) + self._initialize_weights(bias=bias) + + def 
_initialize_weights(self, bias: bool = True): + if self.adaln_type == "local": + for block in self.blocks: + nn.init.constant_(block.adaln_mlp.mlp[-1].weight, 0) + if bias: + nn.init.constant_(block.adaln_mlp.mlp[-1].bias, 0) + elif self.adaln_type == "global": + nn.init.constant_(self.adaln_global_mlp.mlp[-1].weight, 0) + if bias: + nn.init.constant_(self.adaln_global_mlp.mlp[-1].bias, 0) + nn.init.constant_(self.norm_out.linear.weight, 0) + nn.init.constant_(self.proj_out.weight, 0) + if bias: + nn.init.constant_(self.norm_out.linear.bias, 0) + nn.init.constant_(self.proj_out.bias, 0) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_attention_mask: torch.BoolTensor, + timestep: torch.Tensor, + attention_mask: torch.BoolTensor | None = None, + latent_cond: torch.Tensor | None = None, + return_dict: bool = True, + ) -> LongCatAudioDiTTransformerOutput | tuple[torch.Tensor]: + dtype = next(self.parameters()).dtype + hidden_states = hidden_states.to(dtype) + encoder_hidden_states = encoder_hidden_states.to(dtype) + timestep = timestep.to(dtype) + batch_size = hidden_states.shape[0] + if timestep.ndim == 0: + timestep = timestep.repeat(batch_size) + timestep_embed = self.time_embed(timestep) + text_mask = encoder_attention_mask.bool() + encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask) + if self.text_conv: + encoder_hidden_states = self.text_conv_layer(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.masked_fill(text_mask.logical_not().unsqueeze(-1), 0.0) + hidden_states = self.input_embed(hidden_states, attention_mask) + if self.use_latent_condition and latent_cond is not None: + latent_cond = self.latent_embed(latent_cond.to(dtype), attention_mask) + hidden_states = self.latent_cond_embedder(torch.cat([hidden_states, latent_cond], dim=-1)) + residual = hidden_states.clone() if self.long_skip else None + rope = self.rotary_embed(hidden_states, hidden_states.shape[1]) 
+ cond_rope = self.rotary_embed(encoder_hidden_states, encoder_hidden_states.shape[1]) + if self.adaln_type == "global": + if self.adaln_use_text_cond: + text_len = text_mask.sum(1).clamp(min=1).to(encoder_hidden_states.dtype) + text_mean = encoder_hidden_states.sum(1) / text_len.unsqueeze(1) + norm_cond = timestep_embed + text_mean + else: + norm_cond = timestep_embed + adaln_global_out = self.adaln_global_mlp(norm_cond) + for block in self.blocks: + hidden_states = block( + hidden_states=hidden_states, + timestep_embed=timestep_embed, + cond=encoder_hidden_states, + mask=attention_mask, + cond_mask=text_mask, + rope=rope, + cond_rope=cond_rope, + adaln_global_out=adaln_global_out, + ) + else: + norm_cond = timestep_embed + for block in self.blocks: + hidden_states = block( + hidden_states=hidden_states, + timestep_embed=timestep_embed, + cond=encoder_hidden_states, + mask=attention_mask, + cond_mask=text_mask, + rope=rope, + cond_rope=cond_rope, + ) + if self.long_skip: + hidden_states = hidden_states + residual + hidden_states = self.norm_out(hidden_states, norm_cond) + hidden_states = self.proj_out(hidden_states) + if attention_mask is not None: + hidden_states = hidden_states * attention_mask.unsqueeze(-1).to(hidden_states.dtype) + if not return_dict: + return (hidden_states,) + return LongCatAudioDiTTransformerOutput(sample=hidden_states) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 05aad6e349f6..154c28d6bc24 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -326,6 +326,7 @@ _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["lucy"] = ["LucyEditPipeline"] _import_structure["longcat_image"] = ["LongCatImagePipeline", "LongCatImageEditPipeline"] + _import_structure["longcat_audio_dit"] = ["LongCatAudioDiTPipeline"] _import_structure["marigold"].extend( [ "MarigoldDepthPipeline", @@ -751,6 +752,7 @@ ) from .llada2 import 
LLaDA2Pipeline, LLaDA2PipelineOutput from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline + from .longcat_audio_dit import LongCatAudioDiTPipeline from .ltx import ( LTXConditionPipeline, LTXI2VLongMultiPromptPipeline, diff --git a/src/diffusers/pipelines/longcat_audio_dit/__init__.py b/src/diffusers/pipelines/longcat_audio_dit/__init__.py new file mode 100644 index 000000000000..61cb89b4140f --- /dev/null +++ b/src/diffusers/pipelines/longcat_audio_dit/__init__.py @@ -0,0 +1,40 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa: F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure['pipeline_longcat_audio_dit'] = ['LongCatAudioDiTPipeline'] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_longcat_audio_dit import LongCatAudioDiTPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()['__file__'], _import_structure, module_spec=__spec__) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py new file mode 100644 index 000000000000..938a33106b5d --- /dev/null +++ 
b/src/diffusers/pipelines/longcat_audio_dit/pipeline_longcat_audio_dit.py @@ -0,0 +1,483 @@ +# Copyright 2026 MeiTuan LongCat-AudioDiT Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from the LongCat-AudioDiT reference implementation: +# https://github.com/meituan-longcat/LongCat-AudioDiT + +import json +import re +from pathlib import Path +from typing import Any + +import torch +import torch.nn.functional as F +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import validate_hf_hub_args +from safetensors.torch import load_file +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoTokenizer, PreTrainedTokenizerBase, UMT5Config, UMT5EncoderModel + +from ...models import LongCatAudioDiTTransformer, LongCatAudioDiTVae +from ...utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging +from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline + + +logger = logging.get_logger(__name__) + + +def _lens_to_mask(lengths: torch.Tensor, length: int | None = None) -> torch.BoolTensor: + if length is None: + length = int(lengths.amax().item()) + seq = torch.arange(length, device=lengths.device) + return seq[None, :] < lengths[:, None] + + +def _normalize_text(text: str) -> str: + text = text.lower() + text = re.sub(r'["β€œβ€β€˜β€™]', " ", text) + text = re.sub(r"\s+", " ", text) + return text.strip() + + +def _approx_duration_from_text(text: str, max_duration: float = 
30.0) -> float: + en_dur_per_char = 0.082 + zh_dur_per_char = 0.21 + text = re.sub(r"\s+", "", text) + num_zh = num_en = num_other = 0 + for char in text: + if "δΈ€" <= char <= "ιΏΏ": + num_zh += 1 + elif char.isalpha(): + num_en += 1 + else: + num_other += 1 + if num_zh > num_en: + num_zh += num_other + else: + num_en += num_other + return min(max_duration, num_zh * zh_dur_per_char + num_en * en_dur_per_char) + + +def _approx_batch_duration_from_prompts(prompts: list[str]) -> float: + if not prompts: + return 0.0 + return max(_approx_duration_from_text(prompt) for prompt in prompts) + + +def _extract_prefixed_state_dict(state_dict: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]: + prefix = f"{prefix}." + return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)} + + +def _load_longcat_tokenizer( + pretrained_model_name_or_path: str | Path, + text_encoder_model: str | None, + tokenizer: PreTrainedTokenizerBase | str | Path | None, + local_files_only: bool | None, + subfolder: str | None = None, +) -> PreTrainedTokenizerBase: + if isinstance(tokenizer, PreTrainedTokenizerBase): + return tokenizer + + tokenizer_source: str | Path | None = tokenizer + if tokenizer_source is None: + pretrained_path = Path(pretrained_model_name_or_path) + local_tokenizer_dir = pretrained_path / (subfolder or "") / "tokenizer" + if pretrained_path.exists() and local_tokenizer_dir.is_dir(): + tokenizer_source = local_tokenizer_dir + else: + tokenizer_source = text_encoder_model or pretrained_model_name_or_path + + if tokenizer_source is None: + raise ValueError("Could not determine tokenizer source for LongCatAudioDiT.") + + tokenizer_kwargs = {"local_files_only": local_files_only} + if not isinstance(tokenizer_source, Path) and tokenizer_source == pretrained_model_name_or_path and subfolder: + tokenizer_kwargs["subfolder"] = subfolder + return AutoTokenizer.from_pretrained(tokenizer_source, **tokenizer_kwargs) + + +def 
_resolve_longcat_file( + pretrained_model_name_or_path: str | Path, + filename: str, + cache_dir: str | Path | None = None, + force_download: bool = False, + proxies: dict[str, str] | None = None, + local_files_only: bool | None = None, + token: str | bool | None = None, + revision: str | None = None, + subfolder: str | None = None, + local_dir: str | Path | None = None, + local_dir_use_symlinks: str | bool = "auto", + user_agent: dict[str, str] | None = None, +) -> str: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if Path(pretrained_model_name_or_path).is_dir(): + base = Path(pretrained_model_name_or_path) + if subfolder is not None: + base = base / subfolder + file_path = base / filename + if not file_path.is_file(): + raise EnvironmentError(f"Error no file named {filename} found in directory {base}.") + return str(file_path) + + try: + return hf_hub_download( + pretrained_model_name_or_path, + filename=filename, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + subfolder=subfolder, + revision=revision, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + user_agent=user_agent, + ) + except Exception as err: + raise EnvironmentError( + f"Can't load {filename} for '{pretrained_model_name_or_path}'. If you were trying to load it from " + f"'{HUGGINGFACE_CO_RESOLVE_ENDPOINT}', make sure the repo exists or that your local path is correct." 
+ ) from err + + +class LongCatAudioDiTPipeline(DiffusionPipeline): + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + vae: LongCatAudioDiTVae, + text_encoder: UMT5EncoderModel, + tokenizer: PreTrainedTokenizerBase, + transformer: LongCatAudioDiTTransformer, + ): + super().__init__() + self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer) + self.sample_rate = getattr(vae.config, "sample_rate", 24000) + self.latent_hop = getattr(vae.config, "downsampling_ratio", 2048) + self.latent_dim = getattr(transformer.config, "latent_dim", 64) + self.max_wav_duration = 30.0 + self.text_norm_feat = True + self.text_add_embed = True + + @classmethod + @validate_hf_hub_args + def from_pretrained( + cls, + pretrained_model_name_or_path: str | Path, + tokenizer: PreTrainedTokenizerBase | str | Path | None = None, + torch_dtype: torch.dtype | None = None, + local_files_only: bool | None = None, + **kwargs: Any, + ) -> "LongCatAudioDiTPipeline": + cache_dir = kwargs.pop("cache_dir", None) + local_dir = kwargs.pop("local_dir", None) + local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto") + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + try: + cls.load_config( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + ) + except (EnvironmentError, OSError, ValueError): + pass + else: + return super().from_pretrained( + pretrained_model_name_or_path, + tokenizer=tokenizer, + torch_dtype=torch_dtype, + local_files_only=local_files_only, + cache_dir=cache_dir, + local_dir=local_dir, + 
local_dir_use_symlinks=local_dir_use_symlinks, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + **kwargs, + ) + + if kwargs: + logger.warning("Ignoring unsupported LongCatAudioDiTPipeline.from_pretrained kwargs: %s", sorted(kwargs)) + + config_path = _resolve_longcat_file( + pretrained_model_name_or_path, + "config.json", + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + user_agent={"file_type": "config"}, + ) + weights_path = _resolve_longcat_file( + pretrained_model_name_or_path, + "model.safetensors", + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + user_agent={"file_type": "weights"}, + ) + + with open(config_path) as handle: + config = json.load(handle) + + text_encoder_config = UMT5Config.from_dict(config["text_encoder_config"]) + text_encoder = UMT5EncoderModel(text_encoder_config) + transformer = LongCatAudioDiTTransformer( + dit_dim=config["dit_dim"], + dit_depth=config["dit_depth"], + dit_heads=config["dit_heads"], + dit_text_dim=config["dit_text_dim"], + latent_dim=config["latent_dim"], + dropout=config.get("dit_dropout", 0.0), + bias=config.get("dit_bias", True), + cross_attn=config.get("dit_cross_attn", True), + adaln_type=config.get("dit_adaln_type", "global"), + adaln_use_text_cond=config.get("dit_adaln_use_text_cond", True), + long_skip=config.get("dit_long_skip", True), + text_conv=config.get("dit_text_conv", True), + qk_norm=config.get("dit_qk_norm", True), + cross_attn_norm=config.get("dit_cross_attn_norm", False), + eps=config.get("dit_eps", 1e-6), + 
use_latent_condition=config.get("dit_use_latent_condition", True), + ) + vae_config = dict(config["vae_config"]) + vae_config.pop("model_type", None) + vae = LongCatAudioDiTVae(**vae_config) + + state_dict = load_file(weights_path) + transformer.load_state_dict(_extract_prefixed_state_dict(state_dict, "transformer"), strict=True) + vae.load_state_dict(_extract_prefixed_state_dict(state_dict, "vae"), strict=True) + text_missing, text_unexpected = text_encoder.load_state_dict( + _extract_prefixed_state_dict(state_dict, "text_encoder"), strict=False + ) + allowed_missing = {"shared.weight"} + unexpected_missing = set(text_missing) - allowed_missing + if unexpected_missing: + raise RuntimeError(f"Unexpected missing LongCatAudioDiT text encoder weights: {sorted(unexpected_missing)}") + if text_unexpected: + raise RuntimeError(f"Unexpected LongCatAudioDiT text encoder weights: {sorted(text_unexpected)}") + if "shared.weight" in text_missing: + text_encoder.shared.weight.data.copy_(text_encoder.encoder.embed_tokens.weight.data) + + tokenizer = _load_longcat_tokenizer( + pretrained_model_name_or_path, + config.get("text_encoder_model"), + tokenizer, + local_files_only=local_files_only, + subfolder=subfolder, + ) + + if torch_dtype is not None: + text_encoder = text_encoder.to(dtype=torch_dtype) + transformer = transformer.to(dtype=torch_dtype) + vae = vae.to(dtype=torch_dtype) + + text_encoder.eval() + transformer.eval() + vae.eval() + + pipe = cls(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer) + pipe.sample_rate = config.get("sampling_rate", pipe.sample_rate) + pipe.latent_hop = config.get("latent_hop", pipe.latent_hop) + pipe.max_wav_duration = config.get("max_wav_duration", pipe.max_wav_duration) + pipe.text_norm_feat = config.get("text_norm_feat", pipe.text_norm_feat) + pipe.text_add_embed = config.get("text_add_embed", pipe.text_add_embed) + return pipe + + def encode_prompt(self, prompt: str | list[str], device: torch.device) -> 
tuple[torch.Tensor, torch.Tensor]: + if isinstance(prompt, str): + prompt = [prompt] + model_max_length = getattr(self.tokenizer, "model_max_length", 512) + if not isinstance(model_max_length, int) or model_max_length <= 0 or model_max_length > 32768: + model_max_length = 512 + text_inputs = self.tokenizer( + prompt, + padding="longest", + truncation=True, + max_length=model_max_length, + return_tensors="pt", + ) + input_ids = text_inputs.input_ids.to(device) + attention_mask = text_inputs.attention_mask.to(device) + with torch.no_grad(): + output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) + prompt_embeds = output.last_hidden_state + if self.text_norm_feat: + prompt_embeds = F.layer_norm(prompt_embeds, (prompt_embeds.shape[-1],), eps=1e-6) + if self.text_add_embed and getattr(output, "hidden_states", None): + first_hidden = output.hidden_states[0] + if self.text_norm_feat: + first_hidden = F.layer_norm(first_hidden, (first_hidden.shape[-1],), eps=1e-6) + prompt_embeds = prompt_embeds + first_hidden + lengths = attention_mask.sum(dim=1).to(device) + return prompt_embeds.float(), lengths + + def prepare_latents( + self, + batch_size: int, + duration: int, + device: torch.device, + dtype: torch.dtype, + generator: torch.Generator | list[torch.Generator] | None = None, + ) -> torch.Tensor: + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"Expected {batch_size} generators for batch size {batch_size}, but got {len(generator)}." 
+ ) + generators = generator + else: + generators = [generator] * batch_size + + latents = [ + torch.randn( + duration, + self.latent_dim, + device=device, + dtype=dtype, + generator=generators[idx], + ) + for idx in range(batch_size) + ] + return pad_sequence(latents, padding_value=0.0, batch_first=True) + + @torch.no_grad() + def __call__( + self, + prompt: str | list[str], + negative_prompt: str | list[str] | None = None, + audio_end_in_s: float | None = None, + duration: int | None = None, + num_inference_steps: int = 16, + guidance_scale: float = 4.0, + generator: torch.Generator | list[torch.Generator] | None = None, + output_type: str = "np", + return_dict: bool = True, + ): + if prompt is None: + prompt = [] + elif isinstance(prompt, str): + prompt = [prompt] + else: + prompt = list(prompt) + batch_size = len(prompt) + if batch_size == 0: + raise ValueError("`prompt` must contain at least one prompt.") + + device = self._execution_device + normalized_prompts = [_normalize_text(text) for text in prompt] + if duration is None: + if audio_end_in_s is not None: + duration = int(audio_end_in_s * self.sample_rate // self.latent_hop) + else: + duration = int( + _approx_batch_duration_from_prompts(normalized_prompts) * self.sample_rate // self.latent_hop + ) + max_duration = int(self.max_wav_duration * self.sample_rate // self.latent_hop) + duration = max(1, min(duration, max_duration)) + + text_condition, text_condition_len = self.encode_prompt(normalized_prompts, device) + duration_tensor = torch.full((batch_size,), duration, device=device, dtype=torch.long) + mask = _lens_to_mask(duration_tensor) + text_mask = _lens_to_mask(text_condition_len, length=text_condition.shape[1]) + + if negative_prompt is None: + neg_text = torch.zeros_like(text_condition) + neg_text_len = text_condition_len + neg_text_mask = text_mask + else: + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + else: + negative_prompt = list(negative_prompt) + 
if len(negative_prompt) != batch_size: + raise ValueError( + f"`negative_prompt` must have batch size {batch_size}, but got {len(negative_prompt)} prompts." + ) + neg_text, neg_text_len = self.encode_prompt(negative_prompt, device) + neg_text_mask = _lens_to_mask(neg_text_len, length=neg_text.shape[1]) + + latent_cond = torch.zeros(batch_size, duration, self.latent_dim, device=device, dtype=text_condition.dtype) + latents = self.prepare_latents(batch_size, duration, device, text_condition.dtype, generator=generator) + num_inference_steps = max(2, num_inference_steps) + timesteps = torch.linspace(0, 1, num_inference_steps, device=device, dtype=text_condition.dtype) + sample = latents + + def model_step(curr_t: torch.Tensor, current_sample: torch.Tensor) -> torch.Tensor: + pred = self.transformer( + hidden_states=current_sample, + encoder_hidden_states=text_condition, + encoder_attention_mask=text_mask, + timestep=curr_t.expand(batch_size), + attention_mask=mask, + latent_cond=latent_cond, + ).sample + if guidance_scale <= 1.0: + return pred + null_pred = self.transformer( + hidden_states=current_sample, + encoder_hidden_states=neg_text, + encoder_attention_mask=neg_text_mask, + timestep=curr_t.expand(batch_size), + attention_mask=mask, + latent_cond=latent_cond, + ).sample + return null_pred + (pred - null_pred) * guidance_scale + + for idx in range(len(timesteps) - 1): + curr_t = timesteps[idx] + dt = timesteps[idx + 1] - timesteps[idx] + sample = sample + model_step(curr_t, sample) * dt + + if output_type == "latent": + if not return_dict: + return (sample,) + return AudioPipelineOutput(audios=sample) + + waveform = self.vae.decode(sample.permute(0, 2, 1)).sample + if output_type == "np": + waveform = waveform.cpu().float().numpy() + elif output_type != "pt": + raise ValueError(f"Unsupported output_type: {output_type}") + + if not return_dict: + return (waveform,) + return AudioPipelineOutput(audios=waveform) diff --git 
a/tests/models/transformers/test_models_transformer_longcat_audio_dit.py b/tests/models/transformers/test_models_transformer_longcat_audio_dit.py new file mode 100644 index 000000000000..858f7e8484d1 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_longcat_audio_dit.py @@ -0,0 +1,54 @@ +import torch + +from diffusers import LongCatAudioDiTTransformer + + +def test_longcat_audio_transformer_forward_shape(): + model = LongCatAudioDiTTransformer( + dit_dim=64, + dit_depth=2, + dit_heads=4, + dit_text_dim=32, + latent_dim=8, + text_conv=False, + ) + hidden_states = torch.randn(2, 16, 8) + encoder_hidden_states = torch.randn(2, 10, 32) + encoder_attention_mask = torch.ones(2, 10, dtype=torch.bool) + timestep = torch.tensor([1.0, 1.0]) + + output = model( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + ) + + assert output.sample.shape == hidden_states.shape + + +def test_longcat_audio_transformer_masked_forward(): + model = LongCatAudioDiTTransformer( + dit_dim=64, + dit_depth=2, + dit_heads=4, + dit_text_dim=32, + latent_dim=8, + text_conv=False, + ) + hidden_states = torch.randn(2, 16, 8) + encoder_hidden_states = torch.randn(2, 10, 32) + encoder_attention_mask = torch.tensor([[1] * 10, [1] * 6 + [0] * 4], dtype=torch.bool) + attention_mask = torch.tensor([[1] * 16, [1] * 9 + [0] * 7], dtype=torch.bool) + timestep = torch.tensor([1.0, 1.0]) + + output = model( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + attention_mask=attention_mask, + ) + + assert output.sample.shape == hidden_states.shape + assert torch.all(output.sample[1, 9:] == 0) diff --git a/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py b/tests/pipelines/longcat_audio_dit/test_longcat_audio_dit.py new file mode 100644 index 000000000000..ce16e9e26aab --- /dev/null +++ 
# Copyright 2026 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import unittest
from pathlib import Path

import torch
from safetensors.torch import save_file
from transformers import UMT5Config, UMT5EncoderModel

from diffusers import LongCatAudioDiTPipeline, LongCatAudioDiTTransformer, LongCatAudioDiTVae
from tests.testing_utils import require_torch_accelerator, slow, torch_device


class DummyTokenizer:
    """Stand-in tokenizer that always produces all-ones tensors of four tokens per prompt."""

    model_max_length = 16

    def __call__(self, texts, padding="longest", truncation=True, max_length=None, return_tensors="pt"):
        """Return an object exposing ``input_ids`` and ``attention_mask`` for ``texts``.

        A single string is treated as a batch of one. All other arguments are accepted
        but ignored.
        """
        num_prompts = 1 if isinstance(texts, str) else len(texts)
        fields = {
            "input_ids": torch.ones(num_prompts, 4, dtype=torch.long),
            "attention_mask": torch.ones(num_prompts, 4, dtype=torch.long),
        }
        # The result is a throwaway class (not an instance); the tensors live on it as
        # class attributes, which is enough for attribute-style access.
        return type("TokenBatch", (), fields)


class LongCatAudioDiTPipelineFastTests(unittest.TestCase):
    """CPU tests for `LongCatAudioDiTPipeline` built from tiny, randomly initialised components."""

    pipeline_class = LongCatAudioDiTPipeline

    def get_dummy_components(self):
        """Build the small text encoder, transformer and VAE (plus a dummy tokenizer) used by the tests."""
        torch.manual_seed(0)

        encoder_config = UMT5Config(d_model=32, num_layers=1, num_heads=4, d_ff=64, vocab_size=128)
        text_encoder = UMT5EncoderModel(encoder_config)

        transformer_kwargs = {
            "dit_dim": 64,
            "dit_depth": 2,
            "dit_heads": 4,
            "dit_text_dim": 32,
            "latent_dim": 8,
            "text_conv": False,
        }
        transformer = LongCatAudioDiTTransformer(**transformer_kwargs)

        vae_kwargs = {
            "in_channels": 1,
            "channels": 16,
            "c_mults": [1, 2],
            "strides": [2],
            "latent_dim": 8,
            "encoder_latent_dim": 16,
            "downsampling_ratio": 2,
            "sample_rate": 24000,
        }
        vae = LongCatAudioDiTVae(**vae_kwargs)

        return {
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": DummyTokenizer(),
            "transformer": transformer,
        }

    def get_dummy_inputs(self, device, seed=0, prompt="soft ocean ambience"):
        """Return call kwargs for a short two-step generation seeded with ``seed``."""
        # On "mps" devices the global seed is set instead of creating a device-bound generator.
        on_mps = str(device).startswith("mps")
        generator = torch.manual_seed(seed) if on_mps else torch.Generator(device=device).manual_seed(seed)

        return {
            "prompt": prompt,
            "audio_end_in_s": 0.1,
            "num_inference_steps": 2,
            "guidance_scale": 1.0,
            "generator": generator,
            "output_type": "pt",
        }

    def _make_pipeline(self, device, configure_progress_bar=True):
        """Instantiate the pipeline from the dummy components and move it to ``device``."""
        pipe = self.pipeline_class(**self.get_dummy_components())
        pipe.to(device)
        if configure_progress_bar:
            pipe.set_progress_bar_config(disable=None)
        return pipe

    def _assert_audio(self, audio, batch_size=None):
        """Check that ``audio`` is a non-empty 3-D tensor; optionally also check batch size and one channel."""
        self.assertEqual(audio.ndim, 3)
        if batch_size is not None:
            self.assertEqual(audio.shape[0], batch_size)
            self.assertEqual(audio.shape[1], 1)
        self.assertGreater(audio.shape[-1], 0)

    def test_inference(self):
        """A single prompt yields one single-channel, non-empty audio tensor."""
        device = "cpu"
        pipe = self._make_pipeline(device)

        audio = pipe(**self.get_dummy_inputs(device)).audios

        self._assert_audio(audio, batch_size=1)

    def test_inference_batch_single_identical(self):
        """Two runs with the same seed produce outputs that match within ``atol=1e-4``."""
        device = "cpu"
        pipe = self._make_pipeline(device)

        first, second = (pipe(**self.get_dummy_inputs(device, seed=42)).audios for _ in range(2))

        self.assertTrue(torch.allclose(first, second, atol=1e-4))

    def test_inference_batch_multiple_prompts(self):
        """A list of two prompts yields a batch of two single-channel audio tensors."""
        device = "cpu"
        pipe = self._make_pipeline(device)

        prompts = ["soft ocean ambience", "gentle rain ambience"]
        audio = pipe(**self.get_dummy_inputs(device, seed=42, prompt=prompts)).audios

        self._assert_audio(audio, batch_size=2)

    def test_save_pretrained_roundtrip(self):
        """A pipeline saved with ``save_pretrained`` can be reloaded from disk and run."""
        import tempfile

        device = "cpu"
        pipe = self._make_pipeline(device, configure_progress_bar=False)

        with tempfile.TemporaryDirectory() as tmp_dir:
            pipe.save_pretrained(tmp_dir)
            # The dummy tokenizer is handed back in explicitly when reloading.
            reloaded = self.pipeline_class.from_pretrained(
                tmp_dir,
                tokenizer=DummyTokenizer(),
                local_files_only=True,
            )
            audio = reloaded(**self.get_dummy_inputs(device, seed=0)).audios

        self.assertIsInstance(reloaded, LongCatAudioDiTPipeline)
        self._assert_audio(audio)

    def test_from_pretrained_local_dir(self):
        """Loading from a directory holding only ``config.json`` and ``model.safetensors`` works."""
        import tempfile
        from unittest.mock import patch

        device = "cpu"
        parts = self.get_dummy_components()
        text_encoder, transformer, vae = parts["text_encoder"], parts["transformer"], parts["vae"]

        with tempfile.TemporaryDirectory() as tmp_dir:
            model_dir = Path(tmp_dir) / "longcat-audio-dit"
            model_dir.mkdir()

            vae_config = dict(vae.config)
            vae_config["model_type"] = "longcat_audio_dit_vae"
            config = {
                "dit_dim": 64,
                "dit_depth": 2,
                "dit_heads": 4,
                "dit_text_dim": 32,
                "latent_dim": 8,
                "dit_dropout": 0.0,
                "dit_bias": True,
                "dit_cross_attn": True,
                "dit_adaln_type": "global",
                "dit_adaln_use_text_cond": True,
                "dit_long_skip": True,
                "dit_text_conv": False,
                "dit_qk_norm": True,
                "dit_cross_attn_norm": False,
                "dit_eps": 1e-6,
                "dit_use_latent_condition": True,
                "sampling_rate": 24000,
                "latent_hop": 2,
                "max_wav_duration": 30.0,
                "text_norm_feat": True,
                "text_add_embed": True,
                "text_encoder_model": "dummy-umt5",
                "text_encoder_config": text_encoder.config.to_dict(),
                "vae_config": vae_config,
            }
            (model_dir / "config.json").write_text(json.dumps(config))

            # Flatten all component weights into one prefixed state dict. The text encoder's
            # "shared.weight" entry is deliberately left out of the saved file.
            weights = {}
            for prefix, module in (("text_encoder", text_encoder), ("transformer", transformer), ("vae", vae)):
                for key, tensor in module.state_dict().items():
                    if prefix == "text_encoder" and key == "shared.weight":
                        continue
                    weights[f"{prefix}.{key}"] = tensor
            save_file(weights, model_dir / "model.safetensors")

            # Swap the tokenizer lookup for the dummy tokenizer while loading.
            tokenizer_loader = (
                "diffusers.pipelines.longcat_audio_dit.pipeline_longcat_audio_dit.AutoTokenizer.from_pretrained"
            )
            with patch(tokenizer_loader, return_value=DummyTokenizer()):
                pipe = LongCatAudioDiTPipeline.from_pretrained(model_dir, local_files_only=True)

            audio = pipe(**self.get_dummy_inputs(device, seed=0)).audios

        self.assertIsInstance(pipe, LongCatAudioDiTPipeline)
        self.assertEqual(pipe.sample_rate, 24000)
        self.assertEqual(pipe.latent_hop, 2)
        self._assert_audio(audio)


def test_longcat_audio_top_level_imports():
    """The three LongCat-AudioDiT classes are importable from the top-level ``diffusers`` package."""
    for exported in (LongCatAudioDiTPipeline, LongCatAudioDiTTransformer, LongCatAudioDiTVae):
        assert exported is not None


@slow
@require_torch_accelerator
def test_longcat_audio_pipeline_from_pretrained_real_local_weights():
    """Run two inference steps with real local weights.

    The model directory comes from ``LONGCAT_AUDIO_DIT_MODEL_PATH`` (with a hard-coded
    default) and the tokenizer directory from ``LONGCAT_AUDIO_DIT_TOKENIZER_PATH``. The
    test is skipped when the tokenizer variable is unset or either path does not exist.
    """
    default_model_dir = "/data/models/meituan-longcat/LongCat-AudioDiT-1B"
    model_path = Path(os.getenv("LONGCAT_AUDIO_DIT_MODEL_PATH", default_model_dir))

    tokenizer_path_env = os.getenv("LONGCAT_AUDIO_DIT_TOKENIZER_PATH")
    if tokenizer_path_env is None:
        raise unittest.SkipTest("LONGCAT_AUDIO_DIT_TOKENIZER_PATH is not set")
    tokenizer_path = Path(tokenizer_path_env)

    if not model_path.exists():
        raise unittest.SkipTest(f"LongCat-AudioDiT model path not found: {model_path}")
    if not tokenizer_path.exists():
        raise unittest.SkipTest(f"LongCat-AudioDiT tokenizer path not found: {tokenizer_path}")

    load_kwargs = {
        "tokenizer": tokenizer_path,
        "torch_dtype": torch.float16,
        "local_files_only": True,
    }
    pipe = LongCatAudioDiTPipeline.from_pretrained(model_path, **load_kwargs).to(torch_device)

    result = pipe(
        prompt="A calm ocean wave ambience with soft wind in the background.",
        audio_end_in_s=2.0,
        num_inference_steps=2,
        guidance_scale=4.0,
        output_type="pt",
    )

    assert result.audios.ndim == 3
    assert result.audios.shape[0] == 1
    assert result.audios.shape[1] == 1
    assert result.audios.shape[-1] > 0