diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index d29bb5b2593c..d2fbc4ee7322 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -345,6 +345,9 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device = N ValueError: If `num_inference_steps` is larger than `self.config.num_train_timesteps`. """ + if num_inference_steps <= 0: + raise ValueError("num_inference_steps must be > 0") + if num_inference_steps > self.config.num_train_timesteps: raise ValueError( f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 972c46c6e930..4aa2a3b439b3 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -292,6 +292,8 @@ def set_timesteps( `num_inference_steps` must be `None`. """ + if num_inference_steps is not None and num_inference_steps <= 0: + raise ValueError("num_inference_steps must be > 0") if num_inference_steps is not None and timesteps is not None: raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") diff --git a/tests/schedulers/test_scheduler_ddim.py b/tests/schedulers/test_scheduler_ddim.py index 13b353a44b08..59b1e963c01c 100644 --- a/tests/schedulers/test_scheduler_ddim.py +++ b/tests/schedulers/test_scheduler_ddim.py @@ -174,3 +174,10 @@ def test_full_loop_with_noise(self): assert abs(result_sum.item() - 354.5418) < 1e-2, f" expected result sum 218.4379, but get {result_sum}" assert abs(result_mean.item() - 0.4616) < 1e-3, f" expected result mean 0.2844, but get {result_mean}" + + def test_num_inference_steps_zero_raises(self): + scheduler_class = self.scheduler_classes[0] + scheduler = scheduler_class(**self.get_scheduler_config()) + + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=0) \ No newline at end of file diff --git a/tests/schedulers/test_scheduler_ddpm.py b/tests/schedulers/test_scheduler_ddpm.py index 056b5d83350e..a7ba4c93a8a3 100644 --- a/tests/schedulers/test_scheduler_ddpm.py +++ b/tests/schedulers/test_scheduler_ddpm.py @@ -220,3 +220,10 @@ def test_full_loop_with_noise(self): assert abs(result_sum.item() - 387.9466) < 1e-2, f" expected result sum 387.9466, but get {result_sum}" assert abs(result_mean.item() - 0.5051) < 1e-3, f" expected result mean 0.5051, but get {result_mean}" + + def test_num_inference_steps_zero_raises(self): + scheduler_class = self.scheduler_classes[0] + scheduler = scheduler_class(**self.get_scheduler_config()) + + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=0) \ No newline at end of file