Code for this blog post:

Notebook | Github Link | Colab
Predicting Error and Score Function | Error / Score Prediction | Colab (Large)
Classifier free Guidance and other improvements | Advanced concepts | Colab (Large)

Topics to cover

We did most of the heavy lifting in Part 1 of this series on Diffusion Models. To use diffusion models well in practice, we need a few more improvements. That’s what we will do here.

  1. Time step embedding and concatenation/fusion to the input data.
  2. Error Prediction $\hat \epsilon_0^\ast$ and Score Function Prediction $s$ instead of predicting the actual input $x_0$.
  3. Class conditioned generation or Classifier free guidance, where we guide the diffusion model to generate data based on class labels.

These ideas are extensions of the concepts introduced earlier, but they are vital parts of any practical diffusion model implementation.

Time Step Embedding

During the denoising process, the Neural Network needs to know the time step at which denoising is being done. Passing the time step $t$ as a scalar value is not ideal; it is preferable to pass it to the Neural Network as an embedding. In the Attention is all you need1 paper, the authors proposed sinusoidal embeddings to encode the position of the tokens (time steps in our case).

Pseudocode:

import math
import torch
import torch.nn as nn
# adapted from HuggingFace -- https://huggingface.co/blog/annotated-diffusion

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half_dim = self.dim // 2
        # geometric progression of frequencies, so the embedding values stay small
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim) * -embeddings)
        # outer product of the time steps and the frequencies
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

# generate 8 dimensional time step embeddings
sinusoidalPositionEmbeddings = SinusoidalPositionEmbeddings(8)

Sinusoidal embeddings are small in absolute value and can be fused (added) or concatenated with the input data to give the Neural Network information about the time step at which the denoising process is happening.
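For example, here is a minimal usage sketch building on the module defined above (the time step values are arbitrary):

# a quick sanity check of the embedding shapes and magnitudes
t = torch.tensor([1., 10., 100.])
emb = sinusoidalPositionEmbeddings(t)
print(emb.shape)        # torch.Size([3, 8]) -- one 8-dimensional embedding per time step
print(emb.abs().max())  # <= 1, since every entry is a sine or a cosine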

Passing Time Step Embedding (Fusing and Concatenation):

# MLP to project the time step embedding before passing it to the input
self.position_mlp = nn.Sequential(
          nn.GELU(),
          nn.Linear(8, 2) # 2 is the input data dimension
        )

# look up the precomputed sinusoidal embeddings for this batch of time steps
timestep_embeddings = position_embeddings[timestep.long()]
time_embeddings = self.position_mlp(timestep_embeddings)

# concatenation
# x = (batch_size, input data dimension); time_embeddings = (batch_size, input data dimension)
concat_x = torch.cat((x, time_embeddings), dim=1)

# fusing: split the projected embedding into a shift and a scale
# (this requires the MLP output to be twice the input data dimension)
shift, scale = torch.chunk(time_embeddings, 2, dim=-1)
x = shift + (scale + 1) * x

Fusing the time step embedding with the data is computationally efficient, as it requires significantly fewer weight parameters than concatenation. The key idea is that Neural Networks are powerful enough to separate the added information from the original data without it being passed explicitly. Have a look at this video for better intuition: Positional embeddings in transformers EXPLAINED.
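As a rough illustration of the parameter savings, here is a hedged sketch comparing the first layer of a hypothetical denoising network under the two schemes (the sizes are made up for illustration):

import torch.nn as nn

input_dim, hidden = 2, 128  # hypothetical sizes

# concatenation: the first layer sees x and the time embedding side by side
concat_layer = nn.Linear(input_dim * 2, hidden)
# fusing: the input dimension is unchanged because the embedding is added into x
fused_layer = nn.Linear(input_dim, hidden)

print(sum(p.numel() for p in concat_layer.parameters()))  # 640
print(sum(p.numel() for p in fused_layer.parameters()))   # 384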

If we want to add time step information to an image, this is typically done by broadcasting the time step embedding over the spatial dimensions and adding it along the channel dimension.

Adding Time Step to Image (Fusing):

# image batch data -- (batch_size, height, width, channels)
# timestep embedding -- (batch_size, embedding_dimension)

# ensure channels == embedding_dimension
# broadcast and add the timestep embedding to the image data
import einops
timestep_embedding = einops.rearrange(timestep_embedding, 'b c -> b 1 1 c')

fused_image_data = image_data.add(timestep_embedding)

Fusing is a useful technique for adding any kind of conditioning information to the input. We will use this idea in EMNIST: Blog 3, where we fuse label and time step information with the images during class-conditional generation. This technique was used in Improved Denoising Diffusion Probabilistic Models.
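As a preview of that idea, here is a minimal sketch (all names and sizes are hypothetical) of fusing a label embedding together with a time step embedding into an image batch:

import torch
import torch.nn as nn
import einops

batch, h, w, c = 4, 28, 28, 64
image_data = torch.randn(batch, h, w, c)

label_embedding = nn.Embedding(10, c)        # hypothetical: 10 class labels
labels = torch.randint(0, 10, (batch,))
timestep_embedding = torch.randn(batch, c)   # stand-in for the sinusoidal embedding above

# combine the two sources of conditioning, then broadcast over height and width
conditioning = label_embedding(labels) + timestep_embedding
conditioning = einops.rearrange(conditioning, 'b c -> b 1 1 c')
fused_image_data = image_data + conditioning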

Error Prediction and Score Function Prediction

Predicting the Error $\hat \epsilon_0^\ast$

During the denoising step, the model ingests $\hat x_t$ and outputs $\hat x_0^t$, a prediction of $x_0$. Instead of predicting the input data directly, we now want the model to output a prediction of the error, $\hat\epsilon^\ast_0$.

Let’s investigate the relationship between the error $\epsilon_0^\ast$ and the input data $x_0$. $$ \begin{align} x_t &= \sqrt{\bar\alpha_t}x_0 + \sqrt{(1 - \bar\alpha_t )}\ast\epsilon_0^\ast \cr x_0 &= \frac{x_t-\sqrt{(1 - \bar\alpha_t )}\ast\epsilon_0^\ast}{\sqrt{\bar\alpha_t}} \end{align} $$ Equation 2 shows that $x_t$ and $\epsilon^\ast_0$ together determine $x_0$.
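A quick numerical check of Equation 2 (a minimal sketch; the schedule value is made up):

import torch

alpha_bar_t = torch.tensor(0.5)   # hypothetical value of bar-alpha_t
x_0 = torch.randn(4, 2)           # toy input data
eps = torch.randn_like(x_0)       # source error epsilon*_0

# forward process (Equation 1)
x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * eps
# inverting it (Equation 2) recovers x_0 exactly
x_0_rec = (x_t - torch.sqrt(1 - alpha_bar_t) * eps) / torch.sqrt(alpha_bar_t)
print(torch.allclose(x_0_rec, x_0, atol=1e-5))   # True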

The denoising step now looks like this:

  1. Make a prediction of the source error $\hat\epsilon^\ast_0$ using the Neural Network, $NN(\hat x_t, t)$.
  2. Using Equation 2, evaluate $\hat x_0^t$, the prediction of the input data at time step $t$.
  3. Next step:
    • During training: We want an equivalent version of the loss function $Error\_Loss(\epsilon^\ast_0, \hat \epsilon_0^\ast, t)$ (a minimal sketch of this weighted loss follows after this list). $$ \begin{align} Loss(x_0, \hat x_t, t) = 1/2\ast(\frac{\bar\alpha_{t-1}}{1-\bar\alpha_{t-1}} - \frac{\bar\alpha_t}{1-\bar\alpha_t})\ast\mid\mid x_0-\hat x_0^t\mid\mid_2^2 \cr \text{Making modifications to loss using Equation 2} \nonumber \cr Error\_Loss(\epsilon^\ast_0, \hat \epsilon_0^\ast, t) = \frac{1}{2\sigma^2_q(t)} \frac{(1-\alpha_t)^2}{(1-\bar\alpha_t)\alpha_t}\mid\mid \epsilon^\ast_0-\hat \epsilon^\ast_0\mid\mid_2^2 \end{align} $$
    • During data generation:
      1. Using Equation 2, reconstruct $\hat x_0^t$. Remember, $x_0$ is a function of the Gaussian error and the latent variable.
      2. Clip $\hat x_0^t$ so that it lies in the range $-1$ to $+1$ (the normalized range for the input data): torch.clip(x_reconstructed, -1, 1).
      3. Using Equation 5 (below), get a prediction for the latent at time step $t-1$.
$$ \begin{align} p(\hat x_{t-1}| \hat x_t) \varpropto N(\hat x_{t-1};\frac{\sqrt\alpha_t(1-\bar\alpha_{t-1})\hat x_t + \sqrt{\bar\alpha_{t-1}}(1-\alpha_t)\hat x_0^t}{1-\bar\alpha_t}, \frac{(1-\alpha_t)(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}I) \end{align} $$
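Here is the promised minimal sketch of the weighted error loss in Equation 4, assuming precomputed schedule tensors alphas and alphas_cumprod (the names are assumptions):

import torch

def error_loss(eps_true, eps_pred, t, alphas, alphas_cumprod):
    # alphas[t] = alpha_t, alphas_cumprod[t] = bar-alpha_t
    alpha_t = alphas[t]
    alpha_bar_t = alphas_cumprod[t]
    alpha_bar_prev = alphas_cumprod[t - 1]
    # sigma_q^2(t): the posterior variance that also appears in Equation 5
    sigma_q_sq = (1 - alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
    weight = (1 - alpha_t) ** 2 / (2 * sigma_q_sq * (1 - alpha_bar_t) * alpha_t)
    # in practice many implementations drop this weight and use the plain MSE ("simple" loss)
    return weight * ((eps_true - eps_pred) ** 2).sum(dim=-1)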

Predicting the error has been shown empirically to work well; refer to the Denoising paper2. This is probably due to the clipping of the predicted output at every time step, which keeps the prediction in the normalized range. Refer to this discussion.

Here is an interesting discussion on why predicting error works for images, but may not work well for other domains such as voice generation.

Denoising Step during data generation:

# data_in_batch is the latent variable x at the current timestep
x_reconstructed = data_in_batch.T.sub(pred_data.mul(sd[timestep])).div(torch.sqrt(alphas_[timestep]))
# clip the reconstruction so it stays in the normalized data range
if timestep >= 5:
    x_reconstructed = torch.clip(x_reconstructed, -1., 1.)

# mean of the posterior p(x_{t-1} | x_t, x_0) from Equation 5
mean_data_1 = data_in_batch.T.mul(mean_coeff_1[timestep])
mean_data_2 = x_reconstructed.mul(mean_coeff_2[timestep])
mean_data = mean_data_1.add(mean_data_2)

posterior_data = posterior_variance_corrected[timestep]

# sample the latent variable x at the previous timestep
data_in_batch = torch.normal(mean_data, torch.sqrt(posterior_data)).T

Score Function Prediction

This is yet another variation on diffusion models. Similar to predicting the error as discussed earlier, we could also predict the score function. The score function $s$ is defined as $\nabla_{x_t} \log p(x_t)$, the gradient of the log-density with respect to $x_t$.

All we need is to express $x_0$ as a function of $x_t$ and $s$. We can then define a $Score\_Loss(s, \hat s, t)$ by substituting this expression for $x_0$ in the training loss. In the denoising step, we use the same expression to reconstruct $\hat x_0^t$, clip it so that it lies in the normalized range, and proceed as we did in the earlier section.

The relation between $s$ and $x$ is defined as below: $$ x_0 = \frac{x_t + (1 - \bar\alpha_t )\ast\nabla_{x_t}\log p(x_t)}{\sqrt{\bar\alpha_t}} $$
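For the forward process $q(x_t|x_0)$, the score is just a rescaled version of the source error, $s = -\epsilon^\ast_0/\sqrt{1-\bar\alpha_t}$, which is why the formula above reduces to Equation 2. A quick numerical check (a sketch with a made-up schedule value):

import torch

alpha_bar_t = torch.tensor(0.5)   # hypothetical bar-alpha_t
x_0 = torch.randn(4, 2)
eps = torch.randn_like(x_0)
x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * eps

# score of q(x_t | x_0): gradient of its log-density with respect to x_t
score = -(x_t - torch.sqrt(alpha_bar_t) * x_0) / (1 - alpha_bar_t)
assert torch.allclose(score, -eps / torch.sqrt(1 - alpha_bar_t), atol=1e-5)

# plugging the score into the relation above recovers x_0
x_0_rec = (x_t + (1 - alpha_bar_t) * score) / torch.sqrt(alpha_bar_t)
print(torch.allclose(x_0_rec, x_0, atol=1e-5))   # True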

For more details and intuition on why this is an important interpretation, please refer to the section on Three equivalent interpretations.3 Yang Song has an excellent blog post on Score based generative models.

Guidance & Classifier-free Guidance

We want to control the data we generate. For example, in the image below, we have 2 labels. During generation, we want to direct the model to generate samples from either the Yellow or the Purple class. Classifier Guidance is a way to do this: we guide the denoising process to generate samples that are more likely to belong to the conditioned class.

Figure 1: Data with 2 labels, Circles in Purple and Moons in Yellow

Text-guided diffusion models like Glide use the powerful CLIP network together with guidance techniques to perform text-based image generation.

So far, we have been trying to maximize the likelihood of the data distribution $p(x)$ with diffusion models. This allowed us to randomly sample data points from the data distribution.

A naïve approach to class-conditioned generation is to fuse the label information with the input data.

Pseudocode for naïve idea:

def diffusion(x_0, t):
	code to add noise to x_0
	return x_t

def training():
	for loop until convergence:
		pick an image x_0 from X (batch of images)
		sample t from 1 to T
		x_t = diffusion(x_0, t)
		x_hat_t = NN(x_t, t, y)  # y is the label
		loss_ = loss(x_0, x_hat_t)
		update(NN, grad(loss_))

# generating new data points through denoising steps
def generate_new_data():
	sample x_T from N(0, I)
	for t in range(T, 1, -1):
		# denoising step
		x_hat_t = NN(x_t, t, y)  # y is the desired label
		# get x_{t-1}
		x_{t-1} = func(x_hat_t, t)  # refer EQ - 12
	return x_0  # the final denoised sample

During training and data generation, we will add the conditioned label as an input along with the noisy data and the time step embedding. The output of the model will stay the same as earlier.

The conditioned data can be fused with the input data, just like we fused the time step information.

# haiku code

self.mlp = hk.Sequential([
  hk.Linear(256),
  jax.nn.gelu,
  hk.Linear(256),
])
# conditional vectors encoding (10 digits + 26 + 26 letters, plus 1 extra index for the masked label)
self.embedding_vectors = hk.Embed(10+26+26+1, 64)
self.timestep_embeddings = TimeEmbeddings(64)

# the diffusion model is given x, the time step and the label as input
def __call__(self, x, timesteps, cond=None):
    cond_embedding = None
    conditioning = None
    if timesteps is not None:
      timestep_embeddings = self.timestep_embeddings(timesteps)
      conditioning = timestep_embeddings
    if cond is not None:
      label_embeddings = self.embedding_vectors(cond)
      conditioning = jnp.concatenate([label_embeddings, conditioning], axis=1)
    cond_embedding = self.mlp(conditioning)
    ...

    # fusing time step and label information with the input x
    shift, scale = jnp.split(cond_embedding, indices_or_sections=2, axis=-1)
    x = shift + (scale+1)*x
    ...

Guidance

The above approach may lead to models that have low sample diversity. Researchers have proposed two other forms of guidance: classifier guidance and classifier-free guidance4.

Classifier Guidance guides the generation of new samples with the help of a classifier. The classifier takes a noisy image $x_t$ as input and predicts the label $y$. At each denoising step, the gradient of $\log p(y|x_t)$ with respect to the noisy input is used to nudge the sample towards regions the classifier associates with $y$. The classifier plays an adversarial role here, and the approach has similarities to GANs.
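Here is a minimal sketch of one classifier-guided update, assuming a separately trained classifier on noisy inputs; all names and the guidance_scale parameter are hypothetical:

import torch
import torch.nn.functional as F

def classifier_guided_mean(mean, posterior_var, x_t, t, y, classifier, guidance_scale=1.0):
    # shift the denoising mean towards samples the classifier assigns to label y
    x_in = x_t.detach().requires_grad_(True)
    log_probs = F.log_softmax(classifier(x_in, t), dim=-1)
    selected = log_probs[torch.arange(len(y)), y]
    # gradient of log p(y | x_t) with respect to the noisy input
    grad = torch.autograd.grad(selected.sum(), x_in)[0]
    return mean + guidance_scale * posterior_var * grad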

Classifier-Free Guidance is the approach to use if we do not want to build a separate classifier. It models the guided conditional score of the samples as follows: $$ \underbrace{\nabla \log \tilde p(x|y)}_{\text{guided}} = \lambda \ast \underbrace{\nabla \log p(x|y)}_{\text{conditional}} + (1-\lambda) \ast \underbrace{\nabla \log p(x)}_{\text{unconditional}} $$ The conditional and the unconditional distributions are modelled by the same neural network. To model the conditional distribution, we fuse the label information as shown in the naïve approach above. To model the unconditional distribution, we mask the label information before passing it to the diffusion model. The parameter $\lambda$ controls the diversity of the samples we want to generate; $\lambda=1$ is equivalent to the naïve approach.
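At sampling time, this amounts to two forward passes of the same network, one with the label and one with the label masked out. A minimal sketch in the error-prediction setting (model, mask_label and lam are assumptions):

def cfg_error_prediction(model, x_t, t, y, mask_label, lam=2.0):
    eps_cond = model(x_t, t, y)             # label passed in, as in the naive approach
    eps_uncond = model(x_t, t, mask_label)  # label masked out -> unconditional prediction
    # lam = 1 recovers the naive conditional model; lam > 1 pushes samples towards the class
    return lam * eps_cond + (1 - lam) * eps_uncond

During training, the label is randomly masked for a fraction of the batch, so the same network learns both the conditional and the unconditional distributions.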

I am also adding this perspective on the efficiency of diffusion models, even though, to be honest, I am not sure about it myself :).

Let’s look at some outputs

Figure 2: A GIF show-casing the denoising process; Generating class conditioned samples over T time steps

See you in the next part.


Want to connect? Reach out @varuntul22.