From 3a00206cef588e85b51acce43f7718991c6ecd29 Mon Sep 17 00:00:00 2001 From: Alexey Zolotenkov Date: Sat, 4 Apr 2026 19:55:47 +0200 Subject: [PATCH] Fix Flux2 DreamBooth prior preservation prompt repeats --- examples/dreambooth/train_dreambooth_lora_flux2.py | 12 ++++++++---- .../dreambooth/train_dreambooth_lora_flux2_klein.py | 12 ++++++++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 24ba5d507328..53aaedb36897 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -1740,9 +1740,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): prompt_embeds = prompt_embeds_cache[step] text_ids = text_ids_cache[step] else: - num_repeat_elements = len(prompts) - prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1) - text_ids = text_ids.repeat(num_repeat_elements, 1, 1) + # With prior preservation, prompt_embeds/text_ids already contain [instance, class] entries, + # while collate_fn orders batches as [inst1..instB, class1..classB]. Repeat each entry along + # dim 0 to preserve that grouping instead of interleaving [inst, class, inst, class, ...]. + num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts) + prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0) + text_ids = text_ids.repeat_interleave(num_repeat_elements, dim=0) # Convert images to latent space if args.cache_latents: @@ -1809,10 +1812,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) + weighting, weighting_prior = torch.chunk(weighting, 2, dim=0) # Compute prior loss prior_loss = torch.mean( - (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + (weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( target_prior.shape[0], -1 ), 1, diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py index 942c1317e3a8..fcbfeef490cb 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -1680,9 +1680,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): prompt_embeds = prompt_embeds_cache[step] text_ids = text_ids_cache[step] else: - num_repeat_elements = len(prompts) - prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1) - text_ids = text_ids.repeat(num_repeat_elements, 1, 1) + # With prior preservation, prompt_embeds/text_ids already contain [instance, class] entries, + # while collate_fn orders batches as [inst1..instB, class1..classB]. Repeat each entry along + # dim 0 to preserve that grouping instead of interleaving [inst, class, inst, class, ...]. + num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts) + prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0) + text_ids = text_ids.repeat_interleave(num_repeat_elements, dim=0) # Convert images to latent space if args.cache_latents: @@ -1752,10 +1755,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) + weighting, weighting_prior = torch.chunk(weighting, 2, dim=0) # Compute prior loss prior_loss = torch.mean( - (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + (weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( target_prior.shape[0], -1 ), 1,