Speed up your diffusion model training with Min-SNR
The diffusion process gradually denoises an initial noisy image until a clear, high-quality image is obtained. The initial image is pure random Gaussian noise. The denoising process may take anywhere from as few as 15 steps up to thousands of steps. Generally, the more steps we take, the higher the image quality, although beyond a certain point the improvement becomes negligible. Because of this iterative nature of the diffusion process, training is slow and costly. The authors of the paper Efficient Diffusion Training via Min-SNR Weighting Strategy find that modifying the training objective by adopting multi-task learning significantly speeds up convergence, achieving 3.4× faster convergence compared to previous methods while improving the FID at the same time.
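To make the iterative generation concrete, below is a minimal sketch of a DDPM-style reverse (denoising) loop. The names `ddpm_sample`, `eps_model`, and `betas` are illustrative assumptions, not code from the paper: we assume a noise-prediction model and a β noise schedule.

```python
import torch

def ddpm_sample(eps_model, betas, shape):
    """Minimal DDPM-style reverse process (illustrative sketch): start from
    pure Gaussian noise and iteratively denoise it with eps_model."""
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    x = torch.randn(shape)                        # the initial "image" is pure Gaussian noise
    for t in reversed(range(len(betas))):
        eps = eps_model(x, t)                     # predicted noise at timestep t
        # DDPM posterior mean of x_{t-1} given x_t and the predicted noise
        x = (x - betas[t] / (1.0 - alphas_cumprod[t]).sqrt() * eps) / alphas[t].sqrt()
        if t > 0:
            x = x + betas[t].sqrt() * torch.randn_like(x)   # keep some noise except at the final step
    return x
```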
Slow Convergence
Standard training of denoising diffusion models suffers from slow convergence and therefore requires many GPU hours. The authors of Min-SNR investigated the training process and discovered some interesting results: optimizing the denoising function for a specific noise level (training timestep) can harm other timesteps.
Figure 1. Diffusion denoising timesteps were grouped into clusters of 100 consecutive timesteps. The model was fine-tuned separately on the denoising levels present in each cluster. The loss improved for the cluster itself and for the surrounding timesteps; however, fine-tuning was harmful to timesteps further away. Source
Figure 1 presents fine-tuning of a pre-trained diffusion model solely on chosen denoising levels. While the loss improves for the fine-tuned and surrounding timesteps, it harms the model's performance on other, more distant timesteps. The researchers hypothesize that this is because different steps have different requirements: at each step of the diffusion process, the strength of the denoising varies, and different timesteps have conflicting optimization directions. Consequently, directly optimizing all diffusion steps with the same strength results in many gradient steps taken in sub-optimal directions, which significantly extends training time and slows convergence. As a solution, a multi-task learning approach is proposed.
Multi-Task Learning
Multi-task learning aims to optimize multiple different tasks simultaneously. Several approaches have been proposed, among others:
- Sharing a feature extractor between tasks
- Learning what to share between tasks
- Fine-grained parameters sharing
- Per-task loss weighting to avoid conflicting gradients
The diffusion model already shares a common module between the tasks: the whole U-Net model. The natural direction is to incorporate other techniques alongside the shared model. Multi-task learning with deep neural networks: a survey mentions per-task loss weighting as the most prominent direction for MTL, and the authors of Min-SNR follow this direction.
Treating each denoising level as a separate task allows us to apply multi-task learning to diffusion model training. Some tasks in the diffusion process are harder than others. For example, the initial denoising levels contain almost pure noise, so the correct denoising direction is not obvious, while the last denoising steps already have a well-formed image and mainly reconstruct the input. By treating each noise level as a separate task, we can assign different importance to each of them and weight the gradient updates accordingly (see the sketch below). Jointly optimizing all tasks makes proper weight selection possible: the weights are chosen so that a gradient update for one level does not harm the other levels. The only question left is how to find those optimal weights.
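As an illustration, here is a minimal sketch (not the authors' code) of a per-timestep weighted ε-prediction loss. The names `weighted_diffusion_loss`, `eps_model`, `alphas_cumprod`, and `loss_weight` are assumptions: a noise-prediction network, a precomputed ᾱ schedule, and a table with one weight per timestep.

```python
import torch

def weighted_diffusion_loss(eps_model, x0, alphas_cumprod, loss_weight):
    """Per-timestep (per-task) weighted ε-prediction loss: every timestep is a
    task, and its contribution is scaled by loss_weight[t]. x0 is a batch of
    clean images of shape (B, C, H, W)."""
    b = x0.shape[0]
    t = torch.randint(0, len(alphas_cumprod), (b,), device=x0.device)
    noise = torch.randn_like(x0)
    a_bar = alphas_cumprod[t].view(b, 1, 1, 1)
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise   # forward diffusion q(x_t | x_0)
    per_sample_mse = ((eps_model(x_t, t) - noise) ** 2).mean(dim=(1, 2, 3))
    return (loss_weight[t] * per_sample_mse).mean()
```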
Min-SNR
Min-SNR uses the signal-to-noise ratio (SNR) of each noise level to determine the weight strength. Intuitively, SNR measures how much useful information (signal) is present in the image.
SNR(t) = α_t² / σ_t²
x_t = α_t · x_0 + σ_t · ε
α_t = √(1 − σ_t²)
where ε is noise sampled from the Gaussian distribution N(0, I), σ_t denotes the magnitude of the noise added to the clean data x_0 at timestep t, and α_t can be learned by reparameterization or held constant as a hyperparameter.
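SNR(t) can be precomputed once for the whole noise schedule. Below is a minimal numpy sketch; the linear β schedule is an assumption made here for illustration, and any schedule that defines α_t and σ_t works the same way.

```python
import numpy as np

T = 100
betas = np.linspace(1e-4, 2e-2, T)        # assumed linear β schedule
alphas_cumprod = np.cumprod(1.0 - betas)  # ᾱ_t
alpha_t = np.sqrt(alphas_cumprod)         # α_t  (so x_t = α_t·x_0 + σ_t·ε)
sigma_t = np.sqrt(1.0 - alphas_cumprod)   # σ_t  (α_t = √(1 − σ_t²) by construction)
snr = alpha_t**2 / sigma_t**2             # SNR(t) = α_t² / σ_t², one value per timestep
```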
Figure 2. The upper plot presents the precomputed SNR for 100 steps. After applying the formula min{SNR(t), γ}, it can be used to weight the loss function of a diffusion model that predicts the signal x as its output. Most recent diffusion models, however, predict the added noise instead. The bottom plot presents the precomputed loss weight for a diffusion model predicting the noise ε. Time t = 100 is the initial step of the diffusion process, and t = 0 is the final step.
It is worth noting that the SNR can be computed beforehand for all timesteps. The SNR is small for the initial denoising steps and large for the last denoising steps. Intuitively this makes sense, as the last denoising steps improve the image quality only slightly, mostly reconstructing the input, while the initial denoising levels are very noisy and have no obvious improvement direction. SNR weighting ensures that those uncertain initial steps are updated slowly, while the last steps, for which the update direction is more obvious, are updated with more strength.
Figure 3. Noise level at different timesteps.
The Min-SNR weight is defined as min{SNR(t), γ}. Setting γ = 5 stops the SNR weight from taking excessively large values. As shown by the paper's authors in Figure 4, Min-SNR speeds up model convergence by 3.4× compared to a model trained with every noise level weighted equally, while improving the FID score at the same time.
Because most recent diffusion models predict the noise instead of the signal, the SNR weight has to be divided by SNR(t):
w_t = min{SNR(t), γ} / SNR(t) = min{γ / SNR(t), 1}
where w_t is the loss weight for timestep t.
gamma = 5  # γ: SNR clipping threshold
loss_weight = np.minimum(snr, gamma) / snr  # min{SNR(t), γ} / SNR(t)
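Continuing the illustrative snippets above, this array can be passed as the loss_weight table in the weighted_diffusion_loss sketch from the multi-task section. A quick sanity check of its two ends (assuming the linear β schedule used earlier):

```python
# Noisiest timesteps (SNR(t) < γ): min{SNR(t), γ} / SNR(t) = 1, full weight.
print(loss_weight[-1])   # ≈ 1.0
# Nearly clean timesteps (SNR(t) ≫ γ): weight shrinks towards γ / SNR(t).
print(loss_weight[0])    # ≈ gamma / snr[0], a very small value
```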
Figure 4. Comparison of convergence speed and FID score for different training strategies. Min-SNR with γ = 5 performs best in this setup. Source
Summary
In this blog post, we learned that regular diffusion training is slow because different noise levels have conflicting gradient update directions, and updating one noise level can harm others. By treating each timestep as a separate training task, we can apply multi-task learning methods and avoid gradient conflicts. Weighting the loss for each task successfully balances harder and easier noise levels and ensures that one noise level does not harm others. The authors of the paper show that Min-SNR is an effective strategy for loss weighting: it speeds up convergence by 3.4× while improving the FID at the same time. If you would like to learn more about Stable Diffusion, please check out our other blog posts about classifier-free diffusion model guidance and ControlNet.
Reviewed by: Rafał Pytel