Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 66 additions & 2 deletions examples/dreambooth/train_dreambooth_lora_qwen_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,68 @@ def __getitem__(self, index):
return example


# These helpers only matter for prior preservation, where instance and class prompt
# embedding batches are concatenated and may not share the same mask/sequence length.
def _materialize_prompt_embedding_mask(
prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None
) -> torch.Tensor:
"""Return a dense mask tensor for a prompt embedding batch."""
batch_size, seq_len = prompt_embeds.shape[:2]

if prompt_embeds_mask is None:
return torch.ones((batch_size, seq_len), dtype=torch.long, device=prompt_embeds.device)

if prompt_embeds_mask.shape != (batch_size, seq_len):
raise ValueError(
f"`prompt_embeds_mask` shape {prompt_embeds_mask.shape} must match prompt embeddings shape "
f"({batch_size}, {seq_len})."
)

return prompt_embeds_mask.to(device=prompt_embeds.device)


def _pad_prompt_embedding_pair(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would maybe also add a comment on how this is only relevant during prior preservation.

Copy link
Copy Markdown
Author

@chenyangzhu1 chenyangzhu1 Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

I added the following comment

# These helpers only matter for prior preservation, where instance and class prompt
# embedding batches are concatenated and may not share the same mask/sequence length.

prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None, target_seq_len: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""Pad one prompt embedding batch and its mask to a shared sequence length."""
prompt_embeds_mask = _materialize_prompt_embedding_mask(prompt_embeds, prompt_embeds_mask)
pad_width = target_seq_len - prompt_embeds.shape[1]

if pad_width <= 0:
return prompt_embeds, prompt_embeds_mask

prompt_embeds = torch.cat(
[prompt_embeds, prompt_embeds.new_zeros(prompt_embeds.shape[0], pad_width, prompt_embeds.shape[2])], dim=1
)
prompt_embeds_mask = torch.cat(
[prompt_embeds_mask, prompt_embeds_mask.new_zeros(prompt_embeds_mask.shape[0], pad_width)], dim=1
)

return prompt_embeds, prompt_embeds_mask


def concat_prompt_embedding_batches(
*prompt_embedding_pairs: tuple[torch.Tensor, torch.Tensor | None],
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Concatenate prompt embedding batches while handling missing masks and length mismatches."""
if not prompt_embedding_pairs:
raise ValueError("At least one prompt embedding pair must be provided.")

target_seq_len = max(prompt_embeds.shape[1] for prompt_embeds, _ in prompt_embedding_pairs)
padded_pairs = [
_pad_prompt_embedding_pair(prompt_embeds, prompt_embeds_mask, target_seq_len)
for prompt_embeds, prompt_embeds_mask in prompt_embedding_pairs
]

merged_prompt_embeds = torch.cat([prompt_embeds for prompt_embeds, _ in padded_pairs], dim=0)
merged_mask = torch.cat([prompt_embeds_mask for _, prompt_embeds_mask in padded_pairs], dim=0)

if merged_mask.all():
return merged_prompt_embeds, None

return merged_prompt_embeds, merged_mask


def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
Expand Down Expand Up @@ -1320,8 +1382,10 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
prompt_embeds = instance_prompt_embeds
prompt_embeds_mask = instance_prompt_embeds_mask
if args.with_prior_preservation:
prompt_embeds = torch.cat([prompt_embeds, class_prompt_embeds], dim=0)
prompt_embeds_mask = torch.cat([prompt_embeds_mask, class_prompt_embeds_mask], dim=0)
prompt_embeds, prompt_embeds_mask = concat_prompt_embedding_batches(
(instance_prompt_embeds, instance_prompt_embeds_mask),
(class_prompt_embeds, class_prompt_embeds_mask),
)

# if cache_latents is set to True, we encode images to latents and store them.
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,12 @@
]
)
_import_structure["image_processor"] = [
"IPAdapterMaskProcessor",
"InpaintProcessor",
"IPAdapterMaskProcessor",
"PixArtImageProcessor",
"VaeImageProcessor",
"VaeImageProcessorLDM3D",
]
_import_structure["video_processor"] = ["VideoProcessor"]
_import_structure["models"].extend(
[
"AllegroTransformer3DModel",
Expand Down Expand Up @@ -396,6 +395,7 @@
]
)
_import_structure["training_utils"] = ["EMAModel"]
_import_structure["video_processor"] = ["VideoProcessor"]

try:
if not (is_torch_available() and is_scipy_available()):
Expand Down