diff --git a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py index 103cca81c6a5..c4b1d46af59d 100644 --- a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py @@ -143,7 +143,20 @@ def set_begin_index(self, begin_index: int = 0): self._begin_index = begin_index # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs - def precondition_inputs(self, sample, sigma): + def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Precondition the input sample by scaling it according to the EDM formulation. + + Args: + sample (`torch.Tensor`): + The input sample tensor to precondition. + sigma (`float` or `torch.Tensor`): + The current sigma (noise level) value. + + Returns: + `torch.Tensor`: + The scaled input sample. + """ c_in = self._get_conditioning_c_in(sigma) scaled_sample = sample * c_in return scaled_sample @@ -155,7 +168,27 @@ def precondition_noise(self, sigma): return sigma.atan() / math.pi * 2 # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs - def precondition_outputs(self, sample, model_output, sigma): + def precondition_outputs( + self, + sample: torch.Tensor, + model_output: torch.Tensor, + sigma: Union[float, torch.Tensor], + ) -> torch.Tensor: + """ + Precondition the model outputs according to the EDM formulation. + + Args: + sample (`torch.Tensor`): + The input sample tensor. + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sigma (`float` or `torch.Tensor`): + The current sigma (noise level) value. + + Returns: + `torch.Tensor`: + The denoised sample computed by combining the skip connection and output scaling. + """ sigma_data = self.config.sigma_data c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) @@ -173,13 +206,13 @@ def precondition_outputs(self, sample, model_output, sigma): # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that + need to scale the denoising model input depending on the current timestep. Args: sample (`torch.Tensor`): - The input sample. - timestep (`int`, *optional*): + The input sample tensor. + timestep (`float` or `torch.Tensor`): The current timestep in the diffusion chain. Returns: @@ -242,8 +275,27 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc self.noise_sampler = None # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas - def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + def _compute_karras_sigmas( + self, + ramp: torch.Tensor, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + ) -> torch.Tensor: + """ + Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364). 
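+
+        The schedule interpolates linearly between `sigma_max ** (1 / rho)` and `sigma_min ** (1 / rho)` and
+        raises the result back to the power `rho`. A purely illustrative sketch with the default `rho=7.0`:
+
+        ```py
+        >>> import torch
+
+        >>> ramp = torch.linspace(0, 1, 4)
+        >>> rho, s_min, s_max = 7.0, 0.002, 80.0
+        >>> sigmas = (s_max ** (1 / rho) + ramp * (s_min ** (1 / rho) - s_max ** (1 / rho))) ** rho  # 80.0 -> 0.002
+        ```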
+ + Args: + ramp (`torch.Tensor`): + A tensor of values in [0, 1] representing the interpolation positions. + sigma_min (`float`, *optional*): + Minimum sigma value. If `None`, uses `self.config.sigma_min`. + sigma_max (`float`, *optional*): + Maximum sigma value. If `None`, uses `self.config.sigma_max`. + + Returns: + `torch.Tensor`: + The computed Karras sigma schedule. + """ sigma_min = sigma_min or self.config.sigma_min sigma_max = sigma_max or self.config.sigma_max @@ -254,10 +306,27 @@ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch. return sigmas # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas - def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: - """Implementation closely follows k-diffusion. - + def _compute_exponential_sigmas( + self, + ramp: torch.Tensor, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + ) -> torch.Tensor: + """ + Compute the exponential sigma schedule. Implementation closely follows k-diffusion: https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26 + + Args: + ramp (`torch.Tensor`): + A tensor of values representing the interpolation positions. + sigma_min (`float`, *optional*): + Minimum sigma value. If `None`, uses `self.config.sigma_min`. + sigma_max (`float`, *optional*): + Maximum sigma value. If `None`, uses `self.config.sigma_max`. + + Returns: + `torch.Tensor`: + The computed exponential sigma schedule. """ sigma_min = sigma_min or self.config.sigma_min sigma_max = sigma_max or self.config.sigma_max @@ -354,7 +423,10 @@ def dpm_solver_first_order_update( `torch.Tensor`: The sample tensor at the previous timestep. """ - sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + sigma_t, sigma_s = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + ) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) @@ -540,7 +612,10 @@ def step( [g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed() ) self.noise_sampler = BrownianTreeNoiseSampler( - model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed + model_output, + sigma_min=self.config.sigma_min, + sigma_max=self.config.sigma_max, + seed=seed, ) noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to( model_output.device @@ -612,7 +687,18 @@ def add_noise( return noisy_samples # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in - def _get_conditioning_c_in(self, sigma): + def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: + """ + Compute the input conditioning factor for the EDM formulation. + + Args: + sigma (`float` or `torch.Tensor`): + The current sigma (noise level) value. + + Returns: + `float` or `torch.Tensor`: + The input conditioning factor `c_in`. 
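+
+        As a purely illustrative example with the default `sigma_data=0.5`, `sigma=1.0` gives
+        `c_in = 1 / (1.0 + 0.25) ** 0.5 ≈ 0.894`.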
+ """ c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) return c_in diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py index d4e8ca5e8b18..a573f032cad8 100644 --- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py @@ -175,13 +175,37 @@ def set_begin_index(self, begin_index: int = 0): self._begin_index = begin_index # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs - def precondition_inputs(self, sample, sigma): + def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Precondition the input sample by scaling it according to the EDM formulation. + + Args: + sample (`torch.Tensor`): + The input sample tensor to precondition. + sigma (`float` or `torch.Tensor`): + The current sigma (noise level) value. + + Returns: + `torch.Tensor`: + The scaled input sample. + """ c_in = self._get_conditioning_c_in(sigma) scaled_sample = sample * c_in return scaled_sample # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_noise - def precondition_noise(self, sigma): + def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Precondition the noise level by applying a logarithmic transformation. + + Args: + sigma (`float` or `torch.Tensor`): + The sigma (noise level) value to precondition. + + Returns: + `torch.Tensor`: + The preconditioned noise value computed as `0.25 * log(sigma)`. + """ if not isinstance(sigma, torch.Tensor): sigma = torch.tensor([sigma]) @@ -190,7 +214,27 @@ def precondition_noise(self, sigma): return c_noise # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs - def precondition_outputs(self, sample, model_output, sigma): + def precondition_outputs( + self, + sample: torch.Tensor, + model_output: torch.Tensor, + sigma: Union[float, torch.Tensor], + ) -> torch.Tensor: + """ + Precondition the model outputs according to the EDM formulation. + + Args: + sample (`torch.Tensor`): + The input sample tensor. + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sigma (`float` or `torch.Tensor`): + The current sigma (noise level) value. + + Returns: + `torch.Tensor`: + The denoised sample computed by combining the skip connection and output scaling. + """ sigma_data = self.config.sigma_data c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) @@ -208,13 +252,13 @@ def precondition_outputs(self, sample, model_output, sigma): # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that + need to scale the denoising model input depending on the current timestep. Args: sample (`torch.Tensor`): - The input sample. - timestep (`int`, *optional*): + The input sample tensor. + timestep (`float` or `torch.Tensor`): The current timestep in the diffusion chain. 
Returns: @@ -274,8 +318,27 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas - def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + def _compute_karras_sigmas( + self, + ramp: torch.Tensor, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + ) -> torch.Tensor: + """ + Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364). + + Args: + ramp (`torch.Tensor`): + A tensor of values in [0, 1] representing the interpolation positions. + sigma_min (`float`, *optional*): + Minimum sigma value. If `None`, uses `self.config.sigma_min`. + sigma_max (`float`, *optional*): + Maximum sigma value. If `None`, uses `self.config.sigma_max`. + + Returns: + `torch.Tensor`: + The computed Karras sigma schedule. + """ sigma_min = sigma_min or self.config.sigma_min sigma_max = sigma_max or self.config.sigma_max @@ -286,10 +349,27 @@ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch. return sigmas # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas - def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: - """Implementation closely follows k-diffusion. - + def _compute_exponential_sigmas( + self, + ramp: torch.Tensor, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + ) -> torch.Tensor: + """ + Compute the exponential sigma schedule. Implementation closely follows k-diffusion: https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26 + + Args: + ramp (`torch.Tensor`): + A tensor of values representing the interpolation positions. + sigma_min (`float`, *optional*): + Minimum sigma value. If `None`, uses `self.config.sigma_min`. + sigma_max (`float`, *optional*): + Maximum sigma value. If `None`, uses `self.config.sigma_max`. + + Returns: + `torch.Tensor`: + The computed exponential sigma schedule. """ sigma_min = sigma_min or self.config.sigma_min sigma_max = sigma_max or self.config.sigma_max @@ -433,7 +513,10 @@ def dpm_solver_first_order_update( `torch.Tensor`: The sample tensor at the previous timestep. """ - sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + sigma_t, sigma_s = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + ) alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) lambda_t = torch.log(alpha_t) - torch.log(sigma_t) @@ -684,7 +767,10 @@ def step( if self.config.algorithm_type == "sde-dpmsolver++": noise = randn_tensor( - model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype, ) else: noise = None @@ -757,7 +843,18 @@ def add_noise( return noisy_samples # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in - def _get_conditioning_c_in(self, sigma): + def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: + """ + Compute the input conditioning factor for the EDM formulation. 
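+
+        Concretely, `c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5`. With the default `sigma_data=0.5` this
+        approaches `2.0` as `sigma -> 0` and decays like `1 / sigma` for large `sigma`.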
+
+        Args:
+            sigma (`float` or `torch.Tensor`):
+                The current sigma (noise level) value.
+
+        Returns:
+            `float` or `torch.Tensor`:
+                The input conditioning factor `c_in`.
+        """
         c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
         return c_in
diff --git a/src/diffusers/schedulers/scheduling_edm_euler.py b/src/diffusers/schedulers/scheduling_edm_euler.py
index 2ed05d396514..604d8b3ea6fa 100644
--- a/src/diffusers/schedulers/scheduling_edm_euler.py
+++ b/src/diffusers/schedulers/scheduling_edm_euler.py
@@ -14,7 +14,7 @@
 
 import math
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union
 
 import torch
 
@@ -57,29 +57,28 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
     methods the library implements for all schedulers such as loading and saving.
 
     Args:
-        sigma_min (`float`, *optional*, defaults to 0.002):
+        sigma_min (`float`, *optional*, defaults to `0.002`):
             Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the EDM paper [1]; a reasonable
             range is [0, 10].
-        sigma_max (`float`, *optional*, defaults to 80.0):
+        sigma_max (`float`, *optional*, defaults to `80.0`):
             Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable
             range is [0.2, 80.0].
-        sigma_data (`float`, *optional*, defaults to 0.5):
+        sigma_data (`float`, *optional*, defaults to `0.5`):
             The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
-        sigma_schedule (`str`, *optional*, defaults to `karras`):
-            Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
-            (https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential
-            schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
-        num_train_timesteps (`int`, defaults to 1000):
+        sigma_schedule (`Literal["karras", "exponential"]`, *optional*, defaults to `"karras"`):
+            Sigma schedule to compute the `sigmas`. By default, we use the schedule introduced in the EDM paper
+            (https://huggingface.co/papers/2206.00364). The `"exponential"` schedule was incorporated in this model:
+            https://huggingface.co/stabilityai/cosxl.
+        num_train_timesteps (`int`, *optional*, defaults to `1000`):
             The number of diffusion steps to train the model.
-        prediction_type (`str`, defaults to `epsilon`, *optional*):
-            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
-            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
-            Video](https://huggingface.co/papers/2210.02303) paper).
-        rho (`float`, *optional*, defaults to 7.0):
+        prediction_type (`Literal["epsilon", "v_prediction"]`, *optional*, defaults to `"epsilon"`):
+            Prediction type of the scheduler function. `"epsilon"` predicts the noise of the diffusion process, and
+            `"v_prediction"` predicts the velocity (see section 2.4 of the [Imagen Video](https://huggingface.co/papers/2210.02303) paper).
+        rho (`float`, *optional*, defaults to `7.0`):
             The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
-        final_sigmas_type (`str`, defaults to `"zero"`):
+        final_sigmas_type (`Literal["zero", "sigma_min"]`, *optional*, defaults to `"zero"`):
             The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
-            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0. """ _compatibles = [] @@ -91,12 +90,12 @@ def __init__( sigma_min: float = 0.002, sigma_max: float = 80.0, sigma_data: float = 0.5, - sigma_schedule: str = "karras", + sigma_schedule: Literal["karras", "exponential"] = "karras", num_train_timesteps: int = 1000, - prediction_type: str = "epsilon", + prediction_type: Literal["epsilon", "v_prediction"] = "epsilon", rho: float = 7.0, - final_sigmas_type: str = "zero", # can be "zero" or "sigma_min" - ): + final_sigmas_type: Literal["zero", "sigma_min"] = "zero", + ) -> None: if sigma_schedule not in ["karras", "exponential"]: raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`") @@ -131,26 +130,41 @@ def __init__( self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property - def init_noise_sigma(self): - # standard deviation of the initial noise distribution + def init_noise_sigma(self) -> float: + """ + Return the standard deviation of the initial noise distribution. + + Returns: + `float`: + The initial noise sigma value computed as `(sigma_max**2 + 1) ** 0.5`. + """ return (self.config.sigma_max**2 + 1) ** 0.5 @property - def step_index(self): + def step_index(self) -> Optional[int]: """ - The index counter for current timestep. It will increase 1 after each scheduler step. + Return the index counter for the current timestep. The index will increase by 1 after each scheduler step. + + Returns: + `int` or `None`: + The current step index, or `None` if not yet initialized. """ return self._step_index @property - def begin_index(self): + def begin_index(self) -> Optional[int]: """ - The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + Return the index for the first timestep. This should be set from the pipeline with the `set_begin_index` + method. + + Returns: + `int` or `None`: + The begin index, or `None` if not yet set. """ return self._begin_index # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): + def set_begin_index(self, begin_index: int = 0) -> None: """ Sets the begin index for the scheduler. This function should be run from pipeline before the inference. @@ -160,12 +174,36 @@ def set_begin_index(self, begin_index: int = 0): """ self._begin_index = begin_index - def precondition_inputs(self, sample, sigma): + def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Precondition the input sample by scaling it according to the EDM formulation. + + Args: + sample (`torch.Tensor`): + The input sample tensor to precondition. + sigma (`float` or `torch.Tensor`): + The current sigma (noise level) value. + + Returns: + `torch.Tensor`: + The scaled input sample. + """ c_in = self._get_conditioning_c_in(sigma) scaled_sample = sample * c_in return scaled_sample - def precondition_noise(self, sigma): + def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Precondition the noise level by applying a logarithmic transformation. + + Args: + sigma (`float` or `torch.Tensor`): + The sigma (noise level) value to precondition. + + Returns: + `torch.Tensor`: + The preconditioned noise value computed as `0.25 * log(sigma)`. 
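+
+        For instance, `sigma=1.0` maps to `c_noise = 0.0`, and the natural logarithm makes the mapping grow
+        slowly: `sigma=80.0` gives `c_noise = 0.25 * log(80.0) ≈ 1.096`.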
+ """ if not isinstance(sigma, torch.Tensor): sigma = torch.tensor([sigma]) @@ -173,7 +211,27 @@ def precondition_noise(self, sigma): return c_noise - def precondition_outputs(self, sample, model_output, sigma): + def precondition_outputs( + self, + sample: torch.Tensor, + model_output: torch.Tensor, + sigma: Union[float, torch.Tensor], + ) -> torch.Tensor: + """ + Precondition the model outputs according to the EDM formulation. + + Args: + sample (`torch.Tensor`): + The input sample tensor. + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sigma (`float` or `torch.Tensor`): + The current sigma (noise level) value. + + Returns: + `torch.Tensor`: + The denoised sample computed by combining the skip connection and output scaling. + """ sigma_data = self.config.sigma_data c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) @@ -190,13 +248,13 @@ def precondition_outputs(self, sample, model_output, sigma): def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that + need to scale the denoising model input depending on the current timestep. Args: sample (`torch.Tensor`): - The input sample. - timestep (`int`, *optional*): + The input sample tensor. + timestep (`float` or `torch.Tensor`): The current timestep in the diffusion chain. Returns: @@ -214,19 +272,19 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T def set_timesteps( self, - num_inference_steps: int = None, - device: Union[str, torch.device] = None, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, sigmas: Optional[Union[torch.Tensor, List[float]]] = None, - ): + ) -> None: """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: - num_inference_steps (`int`): + num_inference_steps (`int`, *optional*): The number of diffusion steps used when generating samples with a pre-trained model. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - sigmas (`Union[torch.Tensor, List[float]]`, *optional*): + sigmas (`torch.Tensor` or `List[float]`, *optional*): Custom sigmas to use for the denoising process. If not defined, the default behavior when `num_inference_steps` is passed will be used. """ @@ -262,8 +320,27 @@ def set_timesteps( self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17 - def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: - """Constructs the noise schedule of Karras et al. (2022).""" + def _compute_karras_sigmas( + self, + ramp: torch.Tensor, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + ) -> torch.Tensor: + """ + Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364). + + Args: + ramp (`torch.Tensor`): + A tensor of values in [0, 1] representing the interpolation positions. + sigma_min (`float`, *optional*): + Minimum sigma value. 
If `None`, uses `self.config.sigma_min`. + sigma_max (`float`, *optional*): + Maximum sigma value. If `None`, uses `self.config.sigma_max`. + + Returns: + `torch.Tensor`: + The computed Karras sigma schedule. + """ sigma_min = sigma_min or self.config.sigma_min sigma_max = sigma_max or self.config.sigma_max @@ -273,10 +350,27 @@ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch. sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas - def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: - """Implementation closely follows k-diffusion. - + def _compute_exponential_sigmas( + self, + ramp: torch.Tensor, + sigma_min: Optional[float] = None, + sigma_max: Optional[float] = None, + ) -> torch.Tensor: + """ + Compute the exponential sigma schedule. Implementation closely follows k-diffusion: https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26 + + Args: + ramp (`torch.Tensor`): + A tensor of values representing the interpolation positions. + sigma_min (`float`, *optional*): + Minimum sigma value. If `None`, uses `self.config.sigma_min`. + sigma_max (`float`, *optional*): + Maximum sigma value. If `None`, uses `self.config.sigma_max`. + + Returns: + `torch.Tensor`: + The computed exponential sigma schedule. """ sigma_min = sigma_min or self.config.sigma_min sigma_max = sigma_max or self.config.sigma_max @@ -342,32 +436,38 @@ def step( generator: Optional[torch.Generator] = None, return_dict: bool = True, pred_original_sample: Optional[torch.Tensor] = None, - ) -> Union[EDMEulerSchedulerOutput, Tuple]: + ) -> Union[EDMEulerSchedulerOutput, Tuple[torch.Tensor, torch.Tensor]]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise). Args: model_output (`torch.Tensor`): - The direct output from learned diffusion model. - timestep (`float`): + The direct output from the learned diffusion model. + timestep (`float` or `torch.Tensor`): The current discrete timestep in the diffusion chain. sample (`torch.Tensor`): A current instance of a sample created by the diffusion process. - s_churn (`float`): - s_tmin (`float`): - s_tmax (`float`): - s_noise (`float`, defaults to 1.0): + s_churn (`float`, *optional*, defaults to `0.0`): + The amount of stochasticity to add at each step. Higher values add more noise. + s_tmin (`float`, *optional*, defaults to `0.0`): + The minimum sigma threshold below which no noise is added. + s_tmax (`float`, *optional*, defaults to `float("inf")`): + The maximum sigma threshold above which no noise is added. + s_noise (`float`, *optional*, defaults to `1.0`): Scaling factor for noise added to the sample. generator (`torch.Generator`, *optional*): - A random number generator. - return_dict (`bool`): - Whether or not to return a [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or tuple. + A random number generator for reproducibility. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return an [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or tuple. + pred_original_sample (`torch.Tensor`, *optional*): + The predicted denoised sample from a previous step. If provided, skips recomputation. 
Returns: - [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] is - returned, otherwise a tuple is returned where the first element is the sample tensor. + [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or `tuple`: + If `return_dict` is `True`, an [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the previous sample tensor and the + second element is the predicted original sample tensor. """ if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)): @@ -399,7 +499,10 @@ def step( if gamma > 0: noise = randn_tensor( - model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator + model_output.shape, + dtype=model_output.dtype, + device=model_output.device, + generator=generator, ) eps = noise * s_noise sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 @@ -478,9 +581,20 @@ def add_noise( noisy_samples = original_samples + noise * sigma return noisy_samples - def _get_conditioning_c_in(self, sigma): + def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]: + """ + Compute the input conditioning factor for the EDM formulation. + + Args: + sigma (`float` or `torch.Tensor`): + The current sigma (noise level) value. + + Returns: + `float` or `torch.Tensor`: + The input conditioning factor `c_in`. + """ c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) return c_in - def __len__(self): + def __len__(self) -> int: return self.config.num_train_timesteps
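
A quick smoke test of the annotated entry points (a minimal sketch: the shapes are arbitrary and a random
tensor stands in for the denoiser's forward pass, so the loop only exercises the scheduler API):

```py
import torch

from diffusers.schedulers.scheduling_edm_euler import EDMEulerScheduler

scheduler = EDMEulerScheduler()  # defaults: sigma_min=0.002, sigma_max=80.0, sigma_data=0.5, rho=7.0
scheduler.set_timesteps(num_inference_steps=10)

# Start from pure noise scaled to the initial sigma, as in EDM sampling.
sample = torch.randn(1, 3, 8, 8) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    # `scale_model_input` applies the EDM `c_in` preconditioning to the sample.
    model_input = scheduler.scale_model_input(sample, t)
    model_output = torch.randn_like(model_input)  # stand-in for a model forward pass
    sample = scheduler.step(model_output, t, sample).prev_sample
```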