Browse Source

feat: training + sampling

master
Ramon Calvo 1 year ago
parent
commit
b0ca7b950a
  1. 31
      README.md
  2. BIN
      media/cifar-10-predicted.png
  3. 4
      src/simple_ddpm/model.py
  4. 39
      src/simple_ddpm/sample.py
  5. 23
      src/simple_ddpm/train.py

31
README.md

@ -0,0 +1,31 @@
# Simple DDPM
This is a bare-bones, simple DDPM ([Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)) implementation in PyTorch. The whole implementation (model + training + sampling) does not exceed 400 lines of code. The training setup and U-Net model loosely resemble the description in the original paper, but this is not a one-to-one implementation.
![Predictions on CIFAR-10](./media/cifar-10-predicted.png)
These images were generated after training on CIFAR-10 for 256 epochs on a single RTX 4090.
# Usage
## Installation
It is recommended (but not required) to use [uv](https://github.com/astral-sh/uv) to replicate the Python environment:
```bash
uv sync # If using uv
python -m pip install . # Otherwise
```
## Training
```bash
uv run src/simple_ddpm/train.py # If using uv
python src/simple_ddpm/train.py # Otherwise
```
## Sampling
```bash
uv run src/simple_ddpm/sample.py # If using uv
python src/simple_ddpm/sample.py # Otherwise
```

BIN
media/cifar-10-predicted.png

Binary file not shown.

After

Width:  |  Height:  |  Size: 77 KiB

4
src/simple_ddpm/model.py

@ -6,7 +6,7 @@ import math
T_EMBEDDING_SIZE = 32 T_EMBEDDING_SIZE = 32
def sinusoidal_positional_embedding(max_seq_len, d_model): def sinusoidal_positional_embedding(max_seq_len: int, d_model: int) -> torch.Tensor:
pe = torch.zeros(max_seq_len, d_model) pe = torch.zeros(max_seq_len, d_model)
position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1) position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp( div_term = torch.exp(
@ -149,7 +149,7 @@ class UNet(nn.Module):
intermediates = [] intermediates = []
for enc, dow in zip(self.encoder, self.downsamplers): for enc, dow in zip(self.encoder, self.downsamplers):
x = enc(x, t_emb) x = enc(x, t_emb)
intermediates.append(x) # TODO: Is this a copy? intermediates.append(x)
x = dow(x) x = dow(x)
x = self.bottleneck(x, t_emb) x = self.bottleneck(x, t_emb)

39
src/simple_ddpm/sample.py

@ -1,3 +1,4 @@
from typing import Dict, Union
import torch import torch
from tqdm import tqdm from tqdm import tqdm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -6,12 +7,14 @@ from train import extract, get_diffusion_params
from train import TIMESTEPS, IMAGE_SIZE, CHANNELS from train import TIMESTEPS, IMAGE_SIZE, CHANNELS
from model import UNet from model import UNet
torch.manual_seed(42) torch.manual_seed(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@torch.no_grad() @torch.no_grad()
def p_sample(model, x, t, params): def p_sample(
model, x: torch.Tensor, t: torch.Tensor, params: Dict[str, torch.Tensor]
) -> torch.Tensor:
"""Sample from the model at timestep t""" """Sample from the model at timestep t"""
predicted_noise = model(x, t) predicted_noise = model(x, t)
@ -30,7 +33,14 @@ def p_sample(model, x, t, params):
@torch.no_grad() @torch.no_grad()
def sample_images(model, image_size, batch_size, channels, device, params): def sample_images(
model: torch.nn.Module,
image_size: int,
batch_size: int,
channels: int,
device: torch.device,
params: Dict[str, torch.Tensor],
):
"""Generate new images using the trained model""" """Generate new images using the trained model"""
x = torch.randn(batch_size, channels, image_size, image_size).to(device) x = torch.randn(batch_size, channels, image_size, image_size).to(device)
@ -46,7 +56,7 @@ def sample_images(model, image_size, batch_size, channels, device, params):
return x return x
def show_images(images, title=""): def show_images(images: Union[torch.Tensor, np.array], title=""):
"""Display a batch of images in a grid""" """Display a batch of images in a grid"""
if isinstance(images, torch.Tensor): if isinstance(images, torch.Tensor):
images = images.detach().cpu().numpy() images = images.detach().cpu().numpy()
@ -61,21 +71,24 @@ def show_images(images, title=""):
plt.pause(0.001) plt.pause(0.001)
plt.figure(figsize=(10, 10)) if __name__ == "__main__":
plt.figure(figsize=(10, 10))
params = get_diffusion_params(TIMESTEPS, device) params = get_diffusion_params(TIMESTEPS, device)
model = UNet(32, TIMESTEPS).to(device) model = UNet(32, TIMESTEPS).to(device)
model.load_state_dict(torch.load("model.pkl", weights_only=True)) model.load_state_dict(torch.load("model.pkl", weights_only=True))
model.eval() model.eval()
generated_images = sample_images( generated_images = sample_images(
model=model, model=model,
image_size=IMAGE_SIZE, image_size=IMAGE_SIZE,
batch_size=16, batch_size=16,
channels=CHANNELS, channels=CHANNELS,
device=device, device=device,
params=params, params=params,
) )
show_images(generated_images, title="Generated Images") show_images(generated_images, title="Generated Images")
plt.show()
# Keep the plot open after generation is finished
plt.show()

23
src/simple_ddpm/train.py

@ -1,3 +1,4 @@
from typing import Dict, Callable
import torch import torch
from torchvision import datasets, transforms from torchvision import datasets, transforms
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -20,7 +21,6 @@ def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad) return sum(p.numel() for p in model.parameters() if p.requires_grad)
# Data loading and preprocessing
transform = transforms.Compose( transform = transforms.Compose(
[ [
transforms.Resize(IMAGE_SIZE), transforms.Resize(IMAGE_SIZE),
@ -30,7 +30,6 @@ transform = transforms.Compose(
] ]
) )
# Download and load CIFAR-10
train_dataset = datasets.CIFAR10( train_dataset = datasets.CIFAR10(
root="./data", train=True, download=True, transform=transform root="./data", train=True, download=True, transform=transform
) )
@ -40,13 +39,14 @@ train_loader = DataLoader(
) )
def linear_beta_schedule(timesteps): def get_diffusion_params(
timesteps: int, device: torch.device
) -> Dict[str, torch.Tensor]:
def linear_beta_schedule(timesteps):
beta_start = 0.0001 beta_start = 0.0001
beta_end = 0.02 beta_end = 0.02
return torch.linspace(beta_start, beta_end, timesteps) return torch.linspace(beta_start, beta_end, timesteps)
def get_diffusion_params(timesteps, device):
betas = linear_beta_schedule(timesteps) betas = linear_beta_schedule(timesteps)
alphas = 1.0 - betas alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_cumprod = torch.cumprod(alphas, dim=0)
@ -68,14 +68,14 @@ def get_diffusion_params(timesteps, device):
} }
def extract(a, t, x_shape): def extract(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Tensor.shape):
"""Extract coefficients at specified timesteps t""" """Extract coefficients at specified timesteps t"""
batch_size = t.shape[0] batch_size = t.shape[0]
out = a.gather(-1, t) out = a.gather(-1, t)
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device) return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
def get_loss_fn(model, params): def get_loss_fn(model: torch.nn.Module, params: Dict[str, torch.Tensor]) -> Callable:
def loss_fn(x_0): def loss_fn(x_0):
batch_size = x_0.shape[0] batch_size = x_0.shape[0]
t = torch.randint(0, TIMESTEPS, (batch_size,), device=device) t = torch.randint(0, TIMESTEPS, (batch_size,), device=device)
@ -95,7 +95,9 @@ def get_loss_fn(model, params):
# Training loop template # Training loop template
def train_epoch(model, optimizer, train_loader, loss_fn): def train_epoch(
model: torch.nn.Module, optimize, train_loader: DataLoader, loss_fn: Callable
) -> float:
model.train() model.train()
total_loss = 0 total_loss = 0
@ -118,9 +120,14 @@ if __name__ == "__main__":
model = UNet(32, TIMESTEPS).to(device) model = UNet(32, TIMESTEPS).to(device)
nb_params = count_parameters(model) nb_params = count_parameters(model)
print(f"Total number of parameters: {nb_params}") print(f"Total number of parameters: {nb_params}")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95)) optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95))
params = get_diffusion_params(TIMESTEPS, device) params = get_diffusion_params(TIMESTEPS, device)
loss_fn = get_loss_fn(model, params) loss_fn = get_loss_fn(model, params)
# Main training loop
for e in tqdm(range(NUM_EPOCHS)): for e in tqdm(range(NUM_EPOCHS)):
train_epoch(model, optimizer, train_loader, loss_fn) train_epoch(model, optimizer, train_loader, loss_fn)
# Save model after training
torch.save(model.state_dict(), "model.pkl") torch.save(model.state_dict(), "model.pkl")

Loading…
Cancel
Save