# 🐍🐍🐍

2import torch
3import numpy as np
4
def default_diffusion_timesteps(inference_steps, training_steps=1000, steps_offset=1):
    """Return evenly spaced diffusion timesteps in descending order.

    The training range is split into strides of
    ``training_steps // inference_steps``. Counting from the end, timestep
    ``i`` is ``i * stride + steps_offset``. This appears to mirror the
    "leading" spacing of diffusers schedulers with ``steps_offset=1``.

    Args:
        inference_steps: Number of timesteps to produce.
        training_steps: Length of the training schedule.
        steps_offset: Constant added to every timestep. Defaults to 1, the
            value this function always used.

    Returns:
        A 1-D tensor of length ``inference_steps``, largest value first.
        It is int64 when all arguments are Python ints.

    Raises:
        ZeroDivisionError: If ``inference_steps`` is 0.

    Note:
        If ``inference_steps > training_steps``, the stride is 0 and every
        timestep equals ``steps_offset``.
    """
    stride = training_steps // inference_steps
    # An integer arange times an integer stride is already integral, so no
    # rounding is needed. Older PyTorch releases also reject ``round()`` on
    # int64 tensors.
    return torch.arange(inference_steps - 1, -1, -1) * stride + steps_offset

def linspace_timesteps(step_count, max_step=999, min_step=0, power=1):
    """Return ``step_count`` integer timesteps running from ``max_step`` down to ``min_step``.

    A float64 ramp from 1 to 0 is raised to ``power`` and mapped onto
    ``[min_step, max_step]``. Powers above 1 push the intermediate
    timesteps toward ``min_step``. The result is rounded half-to-even
    (``torch.round``) and cast to int32.

    Returns:
        A 1-D int32 tensor of length ``step_count``.
    """
    ramp = torch.linspace(1, 0, step_count, dtype=torch.float64)
    warped = torch.pow(ramp, power)
    scaled = warped * (max_step - min_step) + min_step
    return torch.round(scaled).to(torch.int32)

def default_variance_schedule(variance_range, training_steps=1000):
    """Return a forward-process variance schedule with ``training_steps`` entries.

    The square roots of the two endpoint variances are spaced linearly and
    then squared, so the schedule grows quadratically in the step index.

    Args:
        variance_range: ``(start, end)`` pair of variances.
        training_steps: Number of entries in the schedule.

    Returns:
        A 1-D tensor in PyTorch's default floating dtype.
    """
    first, last = variance_range
    sqrt_schedule = torch.linspace(first ** 0.5, last ** 0.5, training_steps)
    return sqrt_schedule.pow(2)

def _to_numpy(values):
    # Convert a tensor (on any device, possibly requiring grad) or an array-like to a NumPy array.
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return np.asarray(values)


# Mathematically equivalent to Hugging Face diffusers' default Euler scheduler sigmas.
def default_sigmas(forward_variance_schedule, diffusion_timesteps):
    """Return the noise levels (sigmas) for the given timesteps, followed by a final 0.

    Args:
        forward_variance_schedule: Per-training-step forward variances. May be
            a tensor on any device (including one that requires grad) or an
            array-like.
        diffusion_timesteps: Timesteps at which to sample the sigma curve. May
            be a tensor on any device or an array-like. Fractional values are
            linearly interpolated. ``np.interp`` clamps out-of-range values to
            the first/last sigma.

    Returns:
        A 1-D float64 CPU tensor of length ``len(diffusion_timesteps) + 1``.
        The last entry is 0.0.
    """
    variances = torch.as_tensor(forward_variance_schedule)
    # cumprod(1 / (1 - beta)) equals 1 / cumprod(1 - beta), so
    # sqrt(that - 1) = sqrt((1 - a) / a) with a = cumprod(1 - beta).
    inverse_variance_complement_cumprod = torch.cumprod(1 / (1 - variances), dim=0)
    # Move to NumPy explicitly. np.interp's implicit conversion fails for
    # CUDA tensors and for tensors that require grad.
    sqrt_inv_snr = _to_numpy((inverse_variance_complement_cumprod - 1) ** 0.5)
    timesteps = _to_numpy(diffusion_timesteps)
    sigmas = np.interp(timesteps, np.arange(len(sqrt_inv_snr)), sqrt_inv_snr)
    return torch.from_numpy(np.concatenate([sigmas, [0.0]]))
