The Diffusion Revolution

Diffusion models have revolutionized generative AI, powering systems like Stable Diffusion, DALL-E 2, and Midjourney. Unlike GANs, which map noise to a sample in a single forward pass, diffusion models learn to gradually denoise data, which gives them more stable training and higher-quality outputs.

This guide explores the mathematical intuition and practical implementation of diffusion models, making them accessible to engineers and researchers.

The Core Idea

Two-Phase Process

  • Forward Process: Gradually add Gaussian noise to data until it becomes pure noise
  • Reverse Process: Train a neural network to undo the noising step by step, recovering the original data

The key insight: if we can learn to denoise slightly noisy images, we can chain these denoising steps together to generate images from pure noise.
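
Formally, sampling runs a learned reverse kernel p_θ(x_{t-1} | x_t) for T steps, starting from pure noise x_T ~ N(0, I); the whole generative model is the chain p_θ(x_{0:T}) = p(x_T) ∏_{t=1}^{T} p_θ(x_{t-1} | x_t).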

Forward Diffusion Process

The forward process adds a small amount of Gaussian noise at each timestep, with the amount controlled by a variance schedule β_1, …, β_T.
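
At each step the transition is q(x_t | x_{t-1}) = N(x_t; √(1 − β_t) x_{t-1}, β_t I). Writing α_t = 1 − β_t and ᾱ_t = ∏_{s=1}^{t} α_s, the noise compounds into the closed form q(x_t | x_0) = N(x_t; √ᾱ_t x_0, (1 − ᾱ_t) I), so the implementation can pre-compute cumulative products and jump from x_0 to any timestep in a single call: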

Python
import torch
import torch.nn as nn
import numpy as np

class DiffusionForward:
    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        """Initialize forward diffusion process."""
        self.num_timesteps = num_timesteps
        
        # Linear beta schedule
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        
        # Pre-compute useful values
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat([
            torch.tensor([1.0]), 
            self.alphas_cumprod[:-1]
        ])
        
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
    
    def q_sample(self, x_0, t, noise=None):
        """Add noise to x_0 at timestep t.
        
        q(x_t | x_0) = N(x_t; √ᾱ_t x_0, (1 - ᾱ_t)I)
        """
        if noise is None:
            noise = torch.randn_like(x_0)
        
        # Index the pre-computed coefficients on the same device as x_0 / t,
        # then broadcast over the (C, H, W) dimensions.
        sqrt_alpha_prod = self.sqrt_alphas_cumprod.to(x_0.device)[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_prod = self.sqrt_one_minus_alphas_cumprod.to(x_0.device)[t].view(-1, 1, 1, 1)
        
        return sqrt_alpha_prod * x_0 + sqrt_one_minus_alpha_prod * noise
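
To see it in action, here is a minimal usage sketch on a random stand-in batch (a real pipeline would feed images normalized to [-1, 1]; the shapes below are illustrative):

Python
diffusion = DiffusionForward()
x_0 = torch.randn(4, 3, 64, 64)       # stand-in for a batch of images
t = torch.tensor([0, 250, 500, 999])  # one timestep per sample
x_t = diffusion.q_sample(x_0, t)
print(x_t.shape)  # torch.Size([4, 3, 64, 64])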

U-Net Denoising Architecture

Most diffusion models use a U-Net architecture with attention layers to predict noise:

Python
class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        
    def forward(self, t):
        """Sinusoidal time embedding."""
        half_dim = self.dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None].float() * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        return emb

class UNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.time_mlp = nn.Linear(time_dim, out_channels)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)
        
    def forward(self, x, t):
        h = self.conv1(x)
        h = self.norm1(h)
        h = h + self.time_mlp(t)[:, :, None, None]
        h = torch.relu(h)
        h = self.conv2(h)
        h = self.norm2(h)
        h = torch.relu(h)
        return h

class SimpleUNet(nn.Module):
    def __init__(self, in_channels=3, time_dim=256):
        super().__init__()
        self.time_embed = TimeEmbedding(time_dim)
        
        # Encoder
        self.enc1 = UNetBlock(in_channels, 64, time_dim)
        self.enc2 = UNetBlock(64, 128, time_dim)
        self.enc3 = UNetBlock(128, 256, time_dim)
        
        # Bottleneck
        self.bottleneck = UNetBlock(256, 512, time_dim)
        
        # Decoder
        self.dec3 = UNetBlock(512 + 256, 256, time_dim)
        self.dec2 = UNetBlock(256 + 128, 128, time_dim)
        self.dec1 = UNetBlock(128 + 64, 64, time_dim)
        
        # Output
        self.out = nn.Conv2d(64, in_channels, 1)
        
        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        
    def forward(self, x, t):
        # Time embedding
        t_emb = self.time_embed(t)
        
        # Encoder
        e1 = self.enc1(x, t_emb)
        e2 = self.enc2(self.pool(e1), t_emb)
        e3 = self.enc3(self.pool(e2), t_emb)
        
        # Bottleneck
        b = self.bottleneck(self.pool(e3), t_emb)
        
        # Decoder with skip connections
        d3 = self.dec3(torch.cat([self.upsample(b), e3], dim=1), t_emb)
        d2 = self.dec2(torch.cat([self.upsample(d3), e2], dim=1), t_emb)
        d1 = self.dec1(torch.cat([self.upsample(d2), e1], dim=1), t_emb)
        
        return self.out(d1)
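
Before training, it is worth a quick shape check (a sketch; the 64×64 size is arbitrary, but spatial dimensions must be divisible by 8 because of the three pooling stages):

Python
model = SimpleUNet()
x = torch.randn(2, 3, 64, 64)     # batch of noisy images
t = torch.randint(0, 1000, (2,))  # one timestep per sample
out = model(x, t)
print(out.shape)  # torch.Size([2, 3, 64, 64]) -- predicted noise, same shape as input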

Training Loop

Python
def train_diffusion(model, dataloader, num_epochs=100):
    """Train diffusion model."""
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    diffusion = DiffusionForward()
    
    model.train()
    for epoch in range(num_epochs):
        for batch_idx, (images, _) in enumerate(dataloader):
            images = images.cuda()
            batch_size = images.shape[0]
            
            # Sample random timesteps
            t = torch.randint(0, diffusion.num_timesteps, (batch_size,)).cuda()
            
            # Sample noise
            noise = torch.randn_like(images)
            
            # Add noise to images
            noisy_images = diffusion.q_sample(images, t, noise)
            
            # Predict noise
            predicted_noise = model(noisy_images, t)
            
            # Compute loss (simple MSE)
            loss = nn.MSELoss()(predicted_noise, noise)
            
            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")

Sampling (Reverse Process)

Python
@torch.no_grad()
def sample(model, diffusion, image_size, batch_size=1, num_steps=1000):
    """Generate images using reverse diffusion process."""
    model.eval()
    
    # Start from pure noise
    x = torch.randn(batch_size, 3, image_size, image_size).cuda()
    
    # Reverse diffusion
    for t in reversed(range(num_steps)):
        t_batch = torch.full((batch_size,), t, dtype=torch.long).cuda()
        
        # Predict noise
        predicted_noise = model(x, t_batch)
        
        # Get diffusion parameters
        alpha = diffusion.alphas[t]
        alpha_cumprod = diffusion.alphas_cumprod[t]
        beta = diffusion.betas[t]
        
        # Compute x_{t-1}
        if t > 0:
            noise = torch.randn_like(x)
        else:
            noise = torch.zeros_like(x)
        
        # DDPM sampling step: posterior mean plus σ_t·z, using σ_t² = β_t
        x = (1 / torch.sqrt(alpha)) * (
            x - ((1 - alpha) / torch.sqrt(1 - alpha_cumprod)) * predicted_noise
        ) + torch.sqrt(beta) * noise
    
    return x

# Generate images
generated_images = sample(model, diffusion, image_size=64, batch_size=4)
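
The sampler returns tensors in roughly [-1, 1] (assuming the training data was normalized to that range), so one way to inspect the results is to rescale to [0, 1] and write an image grid with torchvision:

Python
from torchvision.utils import save_image

# Map from [-1, 1] back to [0, 1] before writing to disk.
grid = (generated_images.clamp(-1, 1) + 1) / 2
save_image(grid, "samples.png", nrow=2)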

Advanced Techniques

💡 Improvements

  • Classifier-free guidance: Better control over generation
  • DDIM sampling: Faster sampling with fewer steps
  • Latent diffusion: Work in compressed latent space (Stable Diffusion)
  • Conditional generation: Control with text, class labels, or images
  • Cosine schedule: Better noise scheduling than linear (see the sketch after this list)
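
As one concrete example, the cosine schedule from Nichol & Dhariwal (2021) can stand in for the linear schedule used earlier. A minimal sketch follows; the 0.999 clamp follows the paper, and the function is a hypothetical drop-in for the torch.linspace line in DiffusionForward:

Python
import math
import torch

def cosine_beta_schedule(num_timesteps, s=0.008):
    """Cosine schedule: ᾱ_t follows a squared-cosine curve instead of a line."""
    t = torch.linspace(0, num_timesteps, num_timesteps + 1) / num_timesteps
    alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi / 2) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return betas.clamp(0, 0.999)

# Hypothetical usage inside DiffusionForward.__init__:
# self.betas = cosine_beta_schedule(num_timesteps)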

Latent Diffusion (Stable Diffusion)

Stable Diffusion applies diffusion in the latent space of a VAE, dramatically reducing computational requirements:

Python
import torch
from diffusers import StableDiffusionPipeline

# Load Stable Diffusion
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    torch_dtype=torch.float16
).to("cuda")

# Generate image from text
prompt = "A serene mountain landscape at sunset, digital art"
image = pipe(
    prompt,
    num_inference_steps=50,
    guidance_scale=7.5
).images[0]

image.save("generated.png")

Conclusion

Diffusion models represent a paradigm shift in generative AI. By learning to reverse a noise process, they achieve remarkable generation quality with stable training.

The mathematical elegance combined with practical effectiveness has made diffusion models the foundation of modern image generation systems.

Resources

  • Ho, J., et al. (2020). "Denoising Diffusion Probabilistic Models"
  • Rombach, R., et al. (2022). "High-Resolution Image Synthesis with Latent Diffusion Models"
  • Hugging Face Diffusers: https://github.com/huggingface/diffusers
  • Lilian Weng's Blog: https://lilianweng.github.io/posts/2021-07-11-diffusion-models/