The Diffusion Revolution
Diffusion models have revolutionized generative AI, powering systems like Stable Diffusion, DALL-E 2, and Midjourney. Unlike GANs, which map noise to samples in a single forward pass, diffusion models learn to gradually denoise data, which generally yields more stable training and higher-quality outputs.
This guide explores the mathematical intuition and practical implementation of diffusion models, making them accessible to engineers and researchers.
The Core Idea
Two-Phase Process
- Forward Process: Gradually add Gaussian noise to data until it becomes pure noise
- Reverse Process: Train a neural network to undo the noising step by step, recovering data from noise
The key insight: if we can learn to denoise slightly noisy images, we can chain these denoising steps together to generate images from pure noise.
Forward Diffusion Process
The forward process adds noise at each timestep according to a variance schedule:
import torch
import torch.nn as nn
import numpy as np
class DiffusionForward:
    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        """Initialize forward diffusion process."""
        self.num_timesteps = num_timesteps
        # Linear beta schedule
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        # Pre-compute useful values
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat([
            torch.tensor([1.0]),
            self.alphas_cumprod[:-1]
        ])
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

    def q_sample(self, x_0, t, noise=None):
        """Add noise to x_0 at timestep t.

        q(x_t | x_0) = N(x_t; √ᾱ_t x_0, (1 - ᾱ_t)I)
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        # Index the precomputed coefficients on the same device as x_0
        # so a CUDA timestep tensor does not cause a device mismatch.
        sqrt_alpha_prod = self.sqrt_alphas_cumprod.to(x_0.device)[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_prod = self.sqrt_one_minus_alphas_cumprod.to(x_0.device)[t].view(-1, 1, 1, 1)
        return sqrt_alpha_prod * x_0 + sqrt_one_minus_alpha_prod * noise
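As a quick sanity check, q_sample can be applied at a few increasing timesteps: the signal coefficient √ᾱ_t shrinks toward zero while the noise term grows. The snippet below is a minimal sketch that uses a random batch as a stand-in for real images normalized to [-1, 1]:

# Sanity check: corruption increases with t (random batch stands in for real data).
diffusion = DiffusionForward()
fake_images = torch.randn(4, 3, 64, 64)  # (B, C, H, W), pretend these are images

for t_value in [0, 250, 500, 999]:
    t = torch.full((4,), t_value, dtype=torch.long)
    noisy = diffusion.q_sample(fake_images, t)
    print(f"t={t_value:4d}  sqrt(alpha_bar)={diffusion.sqrt_alphas_cumprod[t_value].item():.4f}")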
U-Net Denoising Architecture
Most diffusion models use a U-Net architecture with attention layers to predict the noise added at each timestep. The simplified version below keeps the U-Net structure and time conditioning but omits attention for brevity:
class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Sinusoidal time embedding."""
        half_dim = self.dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        # Build the frequency table on the same device as t so the module
        # works for both CPU and CUDA timestep tensors.
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None].float() * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        return emb
class UNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.time_mlp = nn.Linear(time_dim, out_channels)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)

    def forward(self, x, t):
        h = self.conv1(x)
        h = self.norm1(h)
        h = h + self.time_mlp(t)[:, :, None, None]
        h = torch.relu(h)
        h = self.conv2(h)
        h = self.norm2(h)
        h = torch.relu(h)
        return h
class SimpleUNet(nn.Module):
    def __init__(self, in_channels=3, time_dim=256):
        super().__init__()
        self.time_embed = TimeEmbedding(time_dim)
        # Encoder
        self.enc1 = UNetBlock(in_channels, 64, time_dim)
        self.enc2 = UNetBlock(64, 128, time_dim)
        self.enc3 = UNetBlock(128, 256, time_dim)
        # Bottleneck
        self.bottleneck = UNetBlock(256, 512, time_dim)
        # Decoder
        self.dec3 = UNetBlock(512 + 256, 256, time_dim)
        self.dec2 = UNetBlock(256 + 128, 128, time_dim)
        self.dec1 = UNetBlock(128 + 64, 64, time_dim)
        # Output
        self.out = nn.Conv2d(64, in_channels, 1)
        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x, t):
        # Time embedding
        t_emb = self.time_embed(t)
        # Encoder
        e1 = self.enc1(x, t_emb)
        e2 = self.enc2(self.pool(e1), t_emb)
        e3 = self.enc3(self.pool(e2), t_emb)
        # Bottleneck
        b = self.bottleneck(self.pool(e3), t_emb)
        # Decoder with skip connections
        d3 = self.dec3(torch.cat([self.upsample(b), e3], dim=1), t_emb)
        d2 = self.dec2(torch.cat([self.upsample(d3), e2], dim=1), t_emb)
        d1 = self.dec1(torch.cat([self.upsample(d2), e1], dim=1), t_emb)
        return self.out(d1)
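Before training, it is worth confirming that the network maps a batch of images plus a batch of timesteps to an output with the same shape as the input, since it predicts per-pixel noise. A minimal sketch with an arbitrary 64×64 resolution (anything divisible by 8 works with the three pooling steps):

# Shape check: the predicted noise must have the same shape as the input.
model = SimpleUNet(in_channels=3, time_dim=256)
x = torch.randn(2, 3, 64, 64)
t = torch.randint(0, 1000, (2,))
print(model(x, t).shape)  # expected: torch.Size([2, 3, 64, 64])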
Training Loop
def train_diffusion(model, dataloader, num_epochs=100):
    """Train diffusion model."""
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    diffusion = DiffusionForward()
    model.train()
    for epoch in range(num_epochs):
        for batch_idx, (images, _) in enumerate(dataloader):
            images = images.cuda()
            batch_size = images.shape[0]
            # Sample random timesteps
            t = torch.randint(0, diffusion.num_timesteps, (batch_size,)).cuda()
            # Sample noise
            noise = torch.randn_like(images)
            # Add noise to images
            noisy_images = diffusion.q_sample(images, t, noise)
            # Predict noise
            predicted_noise = model(noisy_images, t)
            # Compute loss (simple MSE)
            loss = nn.functional.mse_loss(predicted_noise, noise)
            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
Sampling (Reverse Process)
@torch.no_grad()
def sample(model, diffusion, image_size, batch_size=1, num_steps=None):
    """Generate images using the reverse diffusion (DDPM ancestral sampling) process."""
    model.eval()
    # Plain DDPM must walk back through every training timestep,
    # so default to the schedule length used during training.
    if num_steps is None:
        num_steps = diffusion.num_timesteps
    # Start from pure noise
    x = torch.randn(batch_size, 3, image_size, image_size).cuda()
    # Reverse diffusion
    for t in reversed(range(num_steps)):
        t_batch = torch.full((batch_size,), t, dtype=torch.long).cuda()
        # Predict noise
        predicted_noise = model(x, t_batch)
        # Get diffusion parameters for this timestep
        alpha = diffusion.alphas[t]
        alpha_cumprod = diffusion.alphas_cumprod[t]
        beta = diffusion.betas[t]
        # No fresh noise is added at the final step (t = 0)
        if t > 0:
            noise = torch.randn_like(x)
        else:
            noise = torch.zeros_like(x)
        # DDPM sampling formula: posterior mean plus sigma_t * noise
        x = (1 / torch.sqrt(alpha)) * (
            x - ((1 - alpha) / torch.sqrt(1 - alpha_cumprod)) * predicted_noise
        ) + torch.sqrt(beta) * noise
    return x
# Generate images
generated_images = sample(model, diffusion, image_size=64, batch_size=4)
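The sampler returns tensors in roughly the range of the training data. Assuming the [-1, 1] normalization from the training sketch above, the outputs can be rescaled to [0, 1] and written out as a grid with torchvision:

# Rescale from [-1, 1] to [0, 1] and save the batch as an image grid.
from torchvision.utils import save_image

save_image((generated_images.clamp(-1, 1) + 1) / 2, "samples.png", nrow=2)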
Advanced Techniques
💡 Improvements
- Classifier-free guidance: Better control over generation
- DDIM sampling: Faster sampling with fewer steps
- Latent diffusion: Work in compressed latent space (Stable Diffusion)
- Conditional generation: Control with text, class labels, or images
- Cosine schedule: Better noise scheduling than linear (a minimal sketch follows this list)
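As one example of the last item, the cosine schedule from the improved DDPM work (Nichol & Dhariwal, 2021) derives the betas from a cosine-shaped ᾱ_t curve, which destroys information more gently at early timesteps than the linear schedule above. The following is a minimal sketch; the function name and clipping threshold follow common practice rather than any code in this guide:

import math

# Minimal sketch of a cosine beta schedule (Nichol & Dhariwal, 2021).
# alpha_bar(t) follows a squared-cosine curve; betas come from ratios of
# consecutive alpha_bar values and are clipped for numerical stability.
def cosine_beta_schedule(num_timesteps, s=0.008):
    steps = torch.arange(num_timesteps + 1, dtype=torch.float64)
    f = torch.cos(((steps / num_timesteps) + s) / (1 + s) * math.pi / 2) ** 2
    alphas_cumprod = f / f[0]
    betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0, 0.999).float()

# Drop-in usage inside DiffusionForward.__init__:
#     self.betas = cosine_beta_schedule(num_timesteps)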
Latent Diffusion (Stable Diffusion)
Stable Diffusion applies diffusion in the latent space of a VAE, dramatically reducing computational requirements:
from diffusers import StableDiffusionPipeline
# Load Stable Diffusion
pipe = StableDiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1",
torch_dtype=torch.float16
).to("cuda")
# Generate image from text
prompt = "A serene mountain landscape at sunset, digital art"
image = pipe(
prompt,
num_inference_steps=50,
guidance_scale=7.5
).images[0]
image.save("generated.png")
Conclusion
Diffusion models represent a paradigm shift in generative AI. By learning to reverse a noise process, they achieve remarkable generation quality with stable training.
The mathematical elegance combined with practical effectiveness has made diffusion models the foundation of modern image generation systems.
Resources
- Ho, J., et al. (2020). "Denoising Diffusion Probabilistic Models"
- Rombach, R., et al. (2022). "High-Resolution Image Synthesis with Latent Diffusion Models"
- Hugging Face Diffusers: https://github.com/huggingface/diffusers
- Lilian Weng's Blog: https://lilianweng.github.io/posts/2021-07-11-diffusion-models/