Deep learning notes 10: Diffusion models - noise to nice in a few steps
This post is from a series of quick notes written primarily for personal usage while reading random ML/SWE/CS papers. As such they might be incomprehensible and/or flat out wrong.
DDPM - Diffusion Models Beat GANs on Image Synthesis
- Input + a sampled little bit of noise; repeated many times (~1000s of steps) -> pure noise
- x_0 = x; noise(x_{t+1}|x_t): take an input image x from the data distr., apply noise(…) multiple times -> image of pure noise
- If we could invert this process -> generative model: random normal noise image -> original image
- Learn to undo one “little bit of noise” step at a time: distribution noise(x_{t-1}|x_t)
- Sample a random noise image, undo the noise 1000s of times (each step gives a one-step-cleaner image) -> a sample from the clean data distr.
- Reversal gives us a process from normal noise to the data distribution
- Noising up: q: noise(x_{t+1}|x_t) is a well-defined process of adding noise from a Normal distribution
- Each step depends only on the output of the previous step
- Added noise has a diagonal covariance matrix; the new sample is centered at the last one, but down-scaled
- Given a large T and a well-behaved schedule, the last step is nearly an isotropic Gaussian distribution
- Produces a vast amount of data pairs of x_{t-1}, x_t
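A minimal numpy sketch of the noise-up process just described; the image, the step count, and the linear beta schedule are placeholder assumptions, not values from the paper:

```python
import numpy as np

def noise_up(x0, betas, rng=np.random.default_rng(0)):
    """Forward process q: repeatedly mix in Gaussian noise until x_T is ~pure noise.

    Each step depends only on the previous sample: the new sample is the
    down-scaled previous one plus isotropic (diagonal-covariance) Gaussian noise.
    Returns the (x_{t-1}, x_t) training pairs produced along the way.
    """
    pairs = []
    x_prev = x0
    for beta in betas:                            # beta_t: noise level of step t
        noise = rng.standard_normal(x_prev.shape)
        x_next = np.sqrt(1.0 - beta) * x_prev + np.sqrt(beta) * noise
        pairs.append((x_prev, x_next))
        x_prev = x_next
    return pairs

# Toy usage: with a long enough schedule the last sample is ~isotropic Gaussian.
x0 = np.ones((32, 32))                            # stand-in for an input image
betas = np.linspace(1e-4, 0.02, 1000)             # assumed linear schedule
pairs = noise_up(x0, betas)
```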
- Denoising: p: noise(x_{t-1}|x_t) requires the entire data distribution -> approximated via a neural network
- Reversal doesn’t predict a single image but a whole distribution of images (that could’ve been the previous step)
- The output distribution is assumed to be Gaussian (mean, covariance)
- The Gaussian assumption only holds when the noise-up steps are small
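A sketch of the matching denoising loop under the assumptions above; `predict_mean` is a stand-in for the trained network, and the per-step covariance is fixed to beta_t * I rather than predicted:

```python
import numpy as np

def denoise(predict_mean, betas, shape, rng=np.random.default_rng(1)):
    """Reverse process p: start from pure noise and undo one noise step at a time.

    `predict_mean(x_t, t)` stands in for the neural network predicting the mean
    of the Gaussian over possible previous-step images; the covariance is fixed
    here (beta_t * I) instead of being predicted.
    """
    x = rng.standard_normal(shape)                 # x_T: isotropic Gaussian noise
    for t in reversed(range(len(betas))):
        mean = predict_mean(x, t)
        noise = rng.standard_normal(shape) if t > 0 else 0.0   # no noise at the last step
        x = mean + np.sqrt(betas[t]) * noise       # sample x_{t-1} ~ N(mean, beta_t * I)
    return x

# Toy usage with a do-nothing "network", just to make the sketch runnable.
sample = denoise(lambda x_t, t: x_t, betas=np.linspace(1e-4, 0.02, 1000), shape=(32, 32))
```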
- Combination of p and q is ~a VAE (variational auto-encoder) -> just train it
- The true distribution can be easily computed since both images of a training pair are known
- Loss forces the denoising network’s predicted distribution to be close to the true distribution
- The predicted covariance can either be statically set (which actually doesn’t work too badly) or also predicted
- If fixed: it can be derived from the forward noise-up step parameters
- A combination of two loss functions is used, for stability also resampled across timesteps (early noise-up steps are more impactful)
- Simple objective: L2 difference between true and predicted picture / noise
- Variational loss: proper KL divergence VAE loss, including variance, …
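A sketch of the simple objective only (the variational KL term and the timestep resampling are left out); `predict_prev` is a stand-in network predicting the previous-step image directly, with predicting the added noise being an equivalent, commonly used parameterization:

```python
import numpy as np

def simple_loss(predict_prev, pairs):
    """L2 difference between the true previous-step image and the network's
    prediction from the one-step-noisier image, averaged over the pairs."""
    losses = []
    for t, (x_prev, x_next) in enumerate(pairs):
        pred = predict_prev(x_next, t)               # network's guess of x_{t-1}
        losses.append(np.mean((pred - x_prev) ** 2))
    return float(np.mean(losses))

# Toy usage: hand-made (x_{t-1}, x_t) pairs and an identity "network".
rng = np.random.default_rng(2)
x = np.ones((8, 8))
pairs = [(x, x + 0.1 * rng.standard_normal(x.shape)) for _ in range(3)]
loss = simple_loss(lambda x_t, t: x_t, pairs)
```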
- Class-label guided generation can improve performance
- Train a class classifier not only on clean images but also on noised images -> use it to steer the generation
- Analogous to ~GANs; ~shifts the predicted distribution of the step-denoised images to where the specified label is likelier
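A sketch of one classifier-guided denoising step, assuming `predict_mean` (the denoising network) and `classifier_grad` (the gradient of log p(y | x_t) from a classifier also trained on noised images) are given; the exact scaling is an assumption here:

```python
import numpy as np

def guided_step(x_t, t, y, predict_mean, classifier_grad, beta_t, scale=1.0,
                rng=np.random.default_rng(3)):
    """One denoising step whose Gaussian is nudged toward images the classifier
    considers likelier to carry label y.

    `classifier_grad(x_t, t, y)` stands in for grad_x log p(y | x_t) computed
    with a classifier that was trained on noised images as well.
    """
    mean = predict_mean(x_t, t)
    mean = mean + scale * beta_t * classifier_grad(x_t, t, y)   # shift toward label y
    noise = rng.standard_normal(x_t.shape) if t > 0 else 0.0
    return mean + np.sqrt(beta_t) * noise                       # sample of x_{t-1}
```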
- Idea: Have GANs that have multiple discriminators along the path from noise to final image ~ merge these two approaches
Autoregressive Diffusion Models
- New type of auto-regressive models: variables can be decoded in arbitrary order
- Autoregressive models: produce tokens (words/patches/pixels) one after another
- E.g. first a word; primed with it, a second word; based on the first two, a third, …
- Usually in a fixed order; for sentences it starts with the first word, …
- Repeat until the whole output has been generated
- ARDMs: Don’t have to go first to last, the order can be arbitrary
- Can also produce multiple tokens at once, reducing the number of steps at the cost of some accuracy
- At the beginning all tokens are initialized (?randomly/zero?)
- DNN (usually transformer) processes them -> per token output (e.g. distribution over categories)
- A portion of them is sampled and decoded (argmax for categorization) -> concrete outputs for a few tokens
- Concrete outputs replace the random inputs for the sampled tokens, the DNN runs again, a new subset of tokens is decoded, …
- Repeat until all tokens are sampled & set
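A sketch of the decoding loop just described; `model(tokens, decoded)` is a stand-in for the transformer returning a (seq_len, vocab) array of per-token probabilities, and the group size and zero-placeholder initialization are assumptions:

```python
import numpy as np

def ardm_sample(model, seq_len, tokens_per_step=4, rng=np.random.default_rng(4)):
    """Decode tokens in an arbitrary (here random) order, a few at a time."""
    tokens = np.zeros(seq_len, dtype=int)        # placeholder values for every position
    decoded = np.zeros(seq_len, dtype=bool)      # which positions are already concrete
    order = rng.permutation(seq_len)             # arbitrary decoding order
    for start in range(0, seq_len, tokens_per_step):
        probs = model(tokens, decoded)           # per-token output distributions
        group = order[start:start + tokens_per_step]
        tokens[group] = probs[group].argmax(axis=-1)   # decode a few positions at once
        decoded[group] = True                    # their concrete values feed the next pass
    return tokens

# Toy usage with a uniform "model" over a 10-token vocabulary.
uniform = lambda toks, done: np.full((len(toks), 10), 0.1)
print(ardm_sample(uniform, seq_len=12))
```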
- ~Similar to BERT
- Trained by randomly masking words within a sentence -> predicts a distribution over words for the masked tokens
- Training is similar to BERT, just with a non-fixed ratio of blanked tokens
- During training: mask a portion of tokens, average losses for all at once
- Sampling one timestep of one ordering where we decode & compute the loss for all of the remaining/masked tokens
- Left to right allows taking only the next (one) token’s loss -> less noisy
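A sketch of one such training step, assuming the same stand-in `model` signature as above and integer token ids; any loss reweighting across timesteps is left out:

```python
import numpy as np

def ardm_train_loss(model, tokens, mask_id, rng=np.random.default_rng(5)):
    """Blank a random *fraction* of the tokens (not a fixed BERT-style ratio) and
    average the cross-entropy over every masked position at once -- one timestep
    of one random ordering, with losses for all remaining tokens."""
    seq_len = len(tokens)
    keep_ratio = rng.uniform()                       # how far along the ordering we are
    masked = rng.random(seq_len) >= keep_ratio       # positions still to be decoded
    inputs = np.where(masked, mask_id, tokens)       # blank out the masked positions
    probs = model(inputs, ~masked)                   # (seq_len, vocab) distributions
    ce = -np.log(probs[np.arange(seq_len), tokens] + 1e-9)   # per-token cross-entropy
    return float(ce[masked].mean()) if masked.any() else 0.0

# Toy usage: a short integer sequence and the uniform "model" from above.
uniform = lambda toks, done: np.full((len(toks), 10), 0.1)
tokens = np.array([1, 4, 2, 7, 0, 3])
loss = ardm_train_loss(uniform, tokens, mask_id=9)
```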
- Why can’t we sample all tokens at once?
- Tokens aren’t independent -> argmax on one token (collapsing its distribution) influences the other tokens’ distr.
- Sampling multiple at once is faster (fewer steps necessary) but possibly gives less ideal outputs
- Extensions:
- Tokens could be re-sampled
- Multiple pixels at a time can be sampled -> dynamic programming is used to get the ordering/token groups
- Initially sample only rougher values (e.g. out of a few colors), only later revisit & predict the specific color
Written by Petr Houška