From 694574541151e0738055fcb03ce9314ede424a3d Mon Sep 17 00:00:00 2001 From: chenyangzhu1 Date: Thu, 2 Apr 2026 20:36:54 +0800 Subject: [PATCH 1/4] Handle prompt embedding concat in Qwen dreambooth example --- .../train_dreambooth_lora_qwen_image.py | 67 ++++++++++++++++++- 1 file changed, 65 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py index a1e2fa0f6052..5ef80413e18a 100644 --- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py +++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py @@ -91,6 +91,7 @@ if is_wandb_available(): import wandb + wandb.init(project="test", mode="offline") # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.38.0.dev0") @@ -906,6 +907,66 @@ def __getitem__(self, index): return example + +def _materialize_prompt_embedding_mask( + prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None +) -> torch.Tensor: + """Return a dense mask tensor for a prompt embedding batch.""" + batch_size, seq_len = prompt_embeds.shape[:2] + + if prompt_embeds_mask is None: + return torch.ones((batch_size, seq_len), dtype=torch.long, device=prompt_embeds.device) + + if prompt_embeds_mask.shape != (batch_size, seq_len): + raise ValueError( + f"`prompt_embeds_mask` shape {prompt_embeds_mask.shape} must match prompt embeddings shape " + f"({batch_size}, {seq_len})." + ) + + return prompt_embeds_mask.to(device=prompt_embeds.device) + + +def _pad_prompt_embedding_pair( + prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None, target_seq_len: int +) -> tuple[torch.Tensor, torch.Tensor]: + """Pad one prompt embedding batch and its mask to a shared sequence length.""" + prompt_embeds_mask = _materialize_prompt_embedding_mask(prompt_embeds, prompt_embeds_mask) + pad_width = target_seq_len - prompt_embeds.shape[1] + + if pad_width <= 0: + return prompt_embeds, prompt_embeds_mask + + prompt_embeds = torch.cat( + [prompt_embeds, prompt_embeds.new_zeros(prompt_embeds.shape[0], pad_width, prompt_embeds.shape[2])], dim=1 + ) + prompt_embeds_mask = torch.cat( + [prompt_embeds_mask, prompt_embeds_mask.new_zeros(prompt_embeds_mask.shape[0], pad_width)], dim=1 + ) + + return prompt_embeds, prompt_embeds_mask + + +def concat_prompt_embedding_batches( + *prompt_embedding_pairs: tuple[torch.Tensor, torch.Tensor | None], +) -> tuple[torch.Tensor, torch.Tensor | None]: + """Concatenate prompt embedding batches while handling missing masks and length mismatches.""" + if not prompt_embedding_pairs: + raise ValueError("At least one prompt embedding pair must be provided.") + + target_seq_len = max(prompt_embeds.shape[1] for prompt_embeds, _ in prompt_embedding_pairs) + padded_pairs = [ + _pad_prompt_embedding_pair(prompt_embeds, prompt_embeds_mask, target_seq_len) + for prompt_embeds, prompt_embeds_mask in prompt_embedding_pairs + ] + + merged_prompt_embeds = torch.cat([prompt_embeds for prompt_embeds, _ in padded_pairs], dim=0) + merged_mask = torch.cat([prompt_embeds_mask for _, prompt_embeds_mask in padded_pairs], dim=0) + + if merged_mask.all(): + return merged_prompt_embeds, None + + return merged_prompt_embeds, merged_mask + def main(args): if args.report_to == "wandb" and args.hub_token is not None: raise ValueError( @@ -1320,8 +1381,10 @@ def compute_text_embeddings(prompt, text_encoding_pipeline): prompt_embeds = instance_prompt_embeds prompt_embeds_mask = instance_prompt_embeds_mask if args.with_prior_preservation: - prompt_embeds = torch.cat([prompt_embeds, class_prompt_embeds], dim=0) - prompt_embeds_mask = torch.cat([prompt_embeds_mask, class_prompt_embeds_mask], dim=0) + prompt_embeds, prompt_embeds_mask = concat_prompt_embedding_batches( + (instance_prompt_embeds, instance_prompt_embeds_mask), + (class_prompt_embeds, class_prompt_embeds_mask), + ) # if cache_latents is set to True, we encode images to latents and store them. # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided From a99dad60cd7139e88d8154ceb7ef4f9673c93014 Mon Sep 17 00:00:00 2001 From: chenyangzhu1 Date: Thu, 2 Apr 2026 23:01:45 +0800 Subject: [PATCH 2/4] remove wandb config --- examples/dreambooth/train_dreambooth_lora_qwen_image.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py index 5ef80413e18a..5eb2286db7ac 100644 --- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py +++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py @@ -91,7 +91,6 @@ if is_wandb_available(): import wandb - wandb.init(project="test", mode="offline") # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.38.0.dev0") From eb29be2a8e2477752d95d5bb65aea1421b2a78ad Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 3 Apr 2026 05:49:30 +0000 Subject: [PATCH 3/4] Apply style fixes --- examples/dreambooth/train_dreambooth_lora_qwen_image.py | 2 +- src/diffusers/__init__.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py index 5eb2286db7ac..95b39e00a273 100644 --- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py +++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py @@ -906,7 +906,6 @@ def __getitem__(self, index): return example - def _materialize_prompt_embedding_mask( prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None ) -> torch.Tensor: @@ -966,6 +965,7 @@ def concat_prompt_embedding_batches( return merged_prompt_embeds, merged_mask + def main(args): if args.report_to == "wandb" and args.hub_token is not None: raise ValueError( diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0f74c0bbcb4a..e220eb10d562 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -178,13 +178,12 @@ ] ) _import_structure["image_processor"] = [ - "IPAdapterMaskProcessor", "InpaintProcessor", + "IPAdapterMaskProcessor", "PixArtImageProcessor", "VaeImageProcessor", "VaeImageProcessorLDM3D", ] - _import_structure["video_processor"] = ["VideoProcessor"] _import_structure["models"].extend( [ "AllegroTransformer3DModel", @@ -396,6 +395,7 @@ ] ) _import_structure["training_utils"] = ["EMAModel"] + _import_structure["video_processor"] = ["VideoProcessor"] try: if not (is_torch_available() and is_scipy_available()): From 410409ffe7905eaae8717666eb5215c30d1c71ea Mon Sep 17 00:00:00 2001 From: chenyangzhu1 Date: Fri, 3 Apr 2026 14:07:40 +0800 Subject: [PATCH 4/4] add a comment on how this is only relevant during prior preservation. --- examples/dreambooth/train_dreambooth_lora_qwen_image.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py index 95b39e00a273..245aed575c35 100644 --- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py +++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py @@ -906,6 +906,8 @@ def __getitem__(self, index): return example +# These helpers only matter for prior preservation, where instance and class prompt +# embedding batches are concatenated and may not share the same mask/sequence length. def _materialize_prompt_embedding_mask( prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None ) -> torch.Tensor: