diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 67f0bff38fbf..a6f33a4a478c 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -737,6 +737,8 @@ title: LCMScheduler - local: api/schedulers/lms_discrete title: LMSDiscreteScheduler + - local: api/schedulers/ltx_euler_ancestral_rf + title: LTXEulerAncestralRFScheduler - local: api/schedulers/pndm title: PNDMScheduler - local: api/schedulers/repaint diff --git a/docs/source/en/api/schedulers/ltx_euler_ancestral_rf.md b/docs/source/en/api/schedulers/ltx_euler_ancestral_rf.md new file mode 100644 index 000000000000..9d4e8d7a750c --- /dev/null +++ b/docs/source/en/api/schedulers/ltx_euler_ancestral_rf.md @@ -0,0 +1,45 @@ + + +# LTXEulerAncestralRFScheduler + +The `LTXEulerAncestralRFScheduler` implements a K-diffusion-style Euler-Ancestral sampler +for flow / CONST parameterization, closely mirroring ComfyUI's `sample_euler_ancestral_RF` +implementation used for [LTX-Video](https://huggingface.co/docs/diffusers/api/pipelines/ltx_video). + +The scheduler operates on a normalized sigma schedule σ ∈ [0, 1] and reconstructs the clean +estimate as `x0 = x_t − σ_t · v_t` (CONST parametrization). Stochastic noise reinjection is +controlled by `eta` (`eta=0` gives a deterministic Euler step; `eta=1` matches ComfyUI's +default RF behavior). + +This scheduler is used by [`LTXPipeline`], [`LTXImageToVideoPipeline`], and +[`LTXConditionPipeline`]. + +The `eta` parameter must be >= 0. `eta=0` gives a deterministic (DDIM-like) Euler step; +`eta=1` matches ComfyUI's default RF behavior. Values above 1 are accepted but trigger a +one-time warning when the schedule step is too coarse to keep `sigma_down` non-negative. + + + +See also [`FlowMatchEulerDiscreteScheduler`], which this scheduler delegates to for +auto-generated sigma schedules and shares config compatibility with via `_compatibles`. 
+ + + +## LTXEulerAncestralRFScheduler +[[autodoc]] LTXEulerAncestralRFScheduler + +## LTXEulerAncestralRFSchedulerOutput +[[autodoc]] schedulers.scheduling_ltx_euler_ancestral_rf.LTXEulerAncestralRFSchedulerOutput diff --git a/src/diffusers/schedulers/scheduling_ltx_euler_ancestral_rf.py b/src/diffusers/schedulers/scheduling_ltx_euler_ancestral_rf.py index 453c8515c301..8676b21f230c 100644 --- a/src/diffusers/schedulers/scheduling_ltx_euler_ancestral_rf.py +++ b/src/diffusers/schedulers/scheduling_ltx_euler_ancestral_rf.py @@ -1,4 +1,4 @@ -# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved. +# Copyright 2025 Lightricks, Vittoria Lanzo and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -65,8 +65,9 @@ class LTXEulerAncestralRFScheduler(SchedulerMixin, ConfigMixin): num_train_timesteps (`int`, defaults to 1000): Included for config compatibility; not used to build the schedule. eta (`float`, defaults to 1.0): - Stochasticity parameter. `eta=0.0` yields deterministic DDIM-like sampling; `eta=1.0` matches ComfyUI's - default RF behavior. + Stochasticity parameter. Must be >= 0. `eta=0.0` yields deterministic DDIM-like sampling; `eta=1.0` + matches ComfyUI's default RF behavior. Values above 1.0 are accepted but will trigger clamping of + `sigma_down` to [0, sigma_next] with a one-time warning when the schedule step is too coarse. s_noise (`float`, defaults to 1.0): Global scaling factor for the stochastic noise term. """ @@ -82,12 +83,15 @@ def __init__( eta: float = 1.0, s_noise: float = 1.0, ): + if eta < 0: + raise ValueError(f"`eta` must be >= 0, got {eta}.") # Note: num_train_timesteps is kept only for config compatibility. 
self.num_inference_steps: int = None self.sigmas: torch.Tensor | None = None self.timesteps: torch.Tensor | None = None self._step_index: int = None self._begin_index: int = None + self._sigma_down_warned: bool = False # deduplication flag for sigma_down clamp warning @property def step_index(self) -> int: @@ -233,12 +237,23 @@ def set_timesteps( if sigmas_tensor.ndim != 1: raise ValueError(f"`sigmas` must be a 1D tensor, got shape {tuple(sigmas_tensor.shape)}.") + if sigmas_tensor[0].item() > 1.0 + 1e-6: + raise ValueError( + f"`sigmas` values must be in [0, 1] for RF/CONST parameterization, " + f"got max={sigmas_tensor[0].item():.6f}." + ) + + if len(sigmas_tensor) > 1 and not (sigmas_tensor[:-1] >= sigmas_tensor[1:]).all(): + sig_list = sigmas_tensor.tolist() + sig_repr = str(sig_list) if len(sig_list) <= 8 else f"{sig_list[:4]} ... {sig_list[-4:]} (len={len(sig_list)})" + raise ValueError( + f"`sigmas` must be monotonically non-increasing (each entry >= the next), got {sig_repr}" + ) + if sigmas_tensor[-1].abs().item() > 1e-6: logger.warning( - "The last sigma in the schedule is not zero (%.6f). " - "For best compatibility with ComfyUI's RF sampler, the terminal sigma " - "should be 0.0.", - sigmas_tensor[-1].item(), + f"The last sigma in the schedule is not zero ({sigmas_tensor[-1].item():.6f}). " + f"For best compatibility with ComfyUI's RF sampler, the terminal sigma should be 0.0." ) # Move to device once, then derive timesteps. @@ -256,10 +271,8 @@ def set_timesteps( if num_inference_steps is not None and num_inference_steps != len(sigmas) - 1: logger.warning( - "Provided `num_inference_steps=%d` does not match `len(sigmas)-1=%d`. " - "Overriding `num_inference_steps` with `len(sigmas)-1`.", - num_inference_steps, - len(sigmas) - 1, + f"Provided `num_inference_steps={num_inference_steps}` does not match `len(sigmas)-1={len(sigmas) - 1}`. " + f"Overriding `num_inference_steps` with `len(sigmas)-1`." 
) self.num_inference_steps = len(sigmas) - 1 @@ -345,6 +358,20 @@ def step( downstep_ratio = 1.0 + (sigma_next / sigma - 1.0) * eta sigma_down = sigma_next * downstep_ratio + # sigma_down can go negative when eta > 1 on a coarse schedule step, which + # flips sigma_ratio and corrupts the Euler update. Clamp to [0, +inf) and + # emit a one-time warning so the user knows to reduce eta or refine the schedule. + # (sigma_down > sigma_next is not reachable under a valid monotone schedule.) + if sigma_down.item() < 0: + if not self._sigma_down_warned: + logger.warning( + f"`eta`={eta:.3f} caused `sigma_down`={sigma_down.item():.6f} to go negative " + f"(sigma={sigma.item():.6f}, sigma_next={sigma_next.item():.6f}). " + f"Clamping to 0. Reduce `eta` or use a finer schedule to avoid this." + ) + self._sigma_down_warned = True + sigma_down = sigma_down.clamp(min=0.0) + alpha_ip1 = 1.0 - sigma_next alpha_down = 1.0 - sigma_down diff --git a/tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py b/tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py new file mode 100644 index 000000000000..18e5df292bbf --- /dev/null +++ b/tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py @@ -0,0 +1,222 @@ +# Copyright 2025 Vittoria Lanzo and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import LTXEulerAncestralRFScheduler + + +def _make_scheduler(**kwargs): + config = {"num_train_timesteps": 1000, "eta": 1.0, "s_noise": 1.0} + config.update(kwargs) + return LTXEulerAncestralRFScheduler(**config) + + +def _linear_sigmas(n=4): + """Return a monotonically decreasing sigma schedule with terminal 0.""" + return [round(1.0 - i / n, 6) for i in range(n + 1)] + + +class LTXEulerAncestralRFSchedulerTest(unittest.TestCase): + # ------------------------------------------------------------------ + # set_timesteps: input validation + # ------------------------------------------------------------------ + + def test_set_timesteps_explicit_sigmas_valid(self): + scheduler = _make_scheduler() + scheduler.set_timesteps(sigmas=_linear_sigmas(4)) + self.assertEqual(scheduler.num_inference_steps, 4) + self.assertEqual(len(scheduler.sigmas), 5) + + def test_set_timesteps_non_monotone_raises(self): + """ + Non-monotonically-decreasing sigmas must raise ValueError. + Without this check, step() computes sigma_down outside [0, 1] + and sigma_ratio >> 1, silently amplifying the latent. 
+ """ + scheduler = _make_scheduler() + # sigma increases at step 0 -> 1 + with self.assertRaises(ValueError): + scheduler.set_timesteps(sigmas=[0.2, 0.8, 0.5, 0.0]) + + def test_set_timesteps_fully_ascending_raises(self): + scheduler = _make_scheduler() + with self.assertRaises(ValueError): + scheduler.set_timesteps(sigmas=[0.0, 0.5, 1.0]) + + def test_set_timesteps_plateau_is_valid(self): + """Equal consecutive sigmas (plateau steps) must NOT raise — used in img2img partial schedules.""" + scheduler = _make_scheduler() + # plateau at the first two entries is intentional in some set_begin_index workflows + scheduler.set_timesteps(sigmas=[1.0, 1.0, 0.5, 0.0]) + self.assertEqual(scheduler.num_inference_steps, 3) + + def test_set_timesteps_num_inference_steps_auto(self): + """Auto-generated schedule (no explicit sigmas) must initialise correctly.""" + scheduler = _make_scheduler() + scheduler.set_timesteps(num_inference_steps=10) + self.assertEqual(scheduler.num_inference_steps, 10) + self.assertEqual(len(scheduler.sigmas), 11) # N steps + terminal 0 + # Verify the auto-generated schedule is itself monotone + sigmas = scheduler.sigmas + self.assertTrue( + (sigmas[:-1] >= sigmas[1:]).all(), + "Auto-generated sigma schedule is not monotonically non-increasing.", + ) + + # ------------------------------------------------------------------ + # step(): output invariants + # ------------------------------------------------------------------ + + def test_step_output_dtype_fp16_preserved(self): + """prev_sample.dtype must equal sample.dtype for fp16 inputs.""" + scheduler = _make_scheduler() + scheduler.set_timesteps(sigmas=_linear_sigmas(4)) + sample = torch.randn(1, 4, 8, 8, dtype=torch.float16) + model_output = torch.randn_like(sample) + out = scheduler.step(model_output, scheduler.timesteps[0], sample) + self.assertEqual(out.prev_sample.dtype, torch.float16) + + def test_step_output_dtype_fp32_preserved(self): + """prev_sample.dtype must equal sample.dtype for fp32 
inputs.""" + scheduler = _make_scheduler() + scheduler.set_timesteps(sigmas=_linear_sigmas(4)) + sample = torch.randn(1, 4, 8, 8, dtype=torch.float32) + model_output = torch.randn_like(sample) + out = scheduler.step(model_output, scheduler.timesteps[0], sample) + self.assertEqual(out.prev_sample.dtype, torch.float32) + + def test_step_output_shape_preserved(self): + """prev_sample.shape must equal sample.shape.""" + scheduler = _make_scheduler() + scheduler.set_timesteps(sigmas=_linear_sigmas(4)) + sample = torch.randn(2, 4, 16, 16) + model_output = torch.randn_like(sample) + out = scheduler.step(model_output, scheduler.timesteps[0], sample) + self.assertEqual(out.prev_sample.shape, sample.shape) + + def test_step_return_tuple(self): + """return_dict=False must return a tuple whose first element matches return_dict=True.""" + scheduler = _make_scheduler() + scheduler.set_timesteps(sigmas=_linear_sigmas(4)) + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + t = scheduler.timesteps[0] + + torch.manual_seed(0) + out_dict = scheduler.step(model_output, t, sample, return_dict=True) + scheduler._step_index = None # reset step index to replay the same step + torch.manual_seed(0) + out_tuple = scheduler.step(model_output, t, sample, return_dict=False) + + self.assertIsInstance(out_tuple, tuple) + self.assertTrue(torch.allclose(out_dict.prev_sample, out_tuple[0])) + + def test_step_eta_zero_is_deterministic(self): + """ + With eta=0 no noise is injected; the output must be identical regardless + of the generator seed passed. 
+ """ + scheduler = _make_scheduler(eta=0.0) + scheduler.set_timesteps(sigmas=_linear_sigmas(4)) + sample = torch.randn(1, 4, 8, 8, generator=torch.Generator().manual_seed(0)) + model_output = torch.randn(1, 4, 8, 8, generator=torch.Generator().manual_seed(1)) + t = scheduler.timesteps[0] + + out1 = scheduler.step(model_output, t, sample).prev_sample + + scheduler._step_index = None + out2 = scheduler.step( + model_output, t, sample, generator=torch.Generator().manual_seed(99) + ).prev_sample + + self.assertTrue(torch.allclose(out1, out2), "eta=0 step should be fully deterministic.") + + def test_step_final_step_returns_denoised(self): + """At sigma=0 (final denoising step) prev_sample must equal the denoised estimate.""" + scheduler = _make_scheduler(eta=1.0) + # Two-step schedule: [0.5, 0.0] + scheduler.set_timesteps(sigmas=[0.5, 0.0]) + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + + # First (and only real) step + out = scheduler.step(model_output, scheduler.timesteps[0], sample) + # At sigma_next=0 the scheduler must return the clean estimate x0 = x_t - sigma*v_t + expected = sample - 0.5 * model_output + self.assertTrue(torch.allclose(out.prev_sample, expected, atol=1e-5)) + + def test_set_timesteps_sigma_above_one_raises(self): + """Sigmas outside [0, 1] violate the RF/CONST parameterization assumption.""" + scheduler = _make_scheduler() + with self.assertRaises(ValueError): + scheduler.set_timesteps(sigmas=[2.0, 1.0, 0.5, 0.0]) + + def test_step_eta_negative_raises(self): + """eta < 0 is invalid and must raise ValueError at construction time.""" + with self.assertRaises(ValueError): + _make_scheduler(eta=-0.1) + + def test_step_eta_greater_than_one_clamps_sigma_down(self): + """eta > 1 on a coarse schedule pushes sigma_down < 0; must clamp, warn once, and stay finite.""" + scheduler = _make_scheduler(eta=2.0) + # Coarse schedule: large step size maximises the chance sigma_down goes negative + 
scheduler.set_timesteps(sigmas=[0.5, 0.1, 0.0]) + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + self.assertFalse(scheduler._sigma_down_warned) + + out = scheduler.step(model_output, scheduler.timesteps[0], sample) + + # Warning flag must be set (warning was emitted) + self.assertTrue(scheduler._sigma_down_warned) + # Output must be finite (clamp prevented NaN/Inf from negative sigma_down) + self.assertTrue(torch.isfinite(out.prev_sample).all()) + + # Second step must NOT re-emit (deduplication) + self.assertTrue(scheduler._sigma_down_warned) # flag already True; dedup suppresses a second warning + out2 = scheduler.step(model_output, scheduler.timesteps[1], sample) + self.assertTrue(torch.isfinite(out2.prev_sample).all()) + + def test_step_index_advances(self): + """_step_index must increment by 1 on each call.""" + scheduler = _make_scheduler() + scheduler.set_timesteps(sigmas=_linear_sigmas(4)) + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + + for expected_idx in range(4): + scheduler.step(model_output, scheduler.timesteps[expected_idx], sample) + self.assertEqual(scheduler._step_index, expected_idx + 1) + + def test_step_beyond_end_returns_sample(self): + """Calling step() past the last index must not crash and must return a finite tensor.""" + scheduler = _make_scheduler(eta=0.0) + scheduler.set_timesteps(sigmas=[0.5, 0.0]) + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + + # Consume all steps normally + scheduler.step(model_output, scheduler.timesteps[0], sample) + # Force _step_index to the clamped maximum + scheduler._step_index = len(scheduler.sigmas) - 1 + # A further call must not crash and must return a finite tensor + out = scheduler.step(model_output, scheduler.timesteps[-1], sample) + self.assertTrue(torch.isfinite(out.prev_sample).all()) + + +if __name__ == "__main__": + unittest.main()