
feat: training + sampling

master · Ramon Calvo · 1 year ago
parent commit b0ca7b950a

  1. README.md (31)
  2. media/cifar-10-predicted.png (BIN)
  3. src/simple_ddpm/model.py (4)
  4. src/simple_ddpm/sample.py (51)
  5. src/simple_ddpm/train.py (29)
README.md (31)

@@ -0,0 +1,31 @@
# Simple DDPM
This is a bare-bones and simple DDPM ([Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)) implementation in PyTorch. The whole implementation (model + training + sampling) does not exceed 400 lines of code. The training setup and U-Net model loosely follow the description in the original paper, but this is not a one-to-one implementation.
![Predictions on CIFAR-10](./media/cifar-10-predicted.png)
These images were generated after training on CIFAR-10 for 256 epochs on a single RTX 4090.
# Usage
## Installation
It is recommended (but not required) to use [uv](https://github.com/astral-sh/uv) to replicate the Python environment:
```bash
uv sync # If using uv
python -m pip install . # Otherwise
```
## Training
```bash
uv run src/simple_ddpm/train.py # If using uv
python src/simple_ddpm/train.py # Otherwise
```
## Sampling
```bash
uv run src/simple_ddpm/sample.py # If using uv
python src/simple_ddpm/sample.py # Otherwise
```

media/cifar-10-predicted.png (BIN)

Binary file not shown (new image, 77 KiB).

src/simple_ddpm/model.py (4)

@@ -6,7 +6,7 @@ import math
 
 T_EMBEDDING_SIZE = 32
 
-def sinusoidal_positional_embedding(max_seq_len, d_model):
+def sinusoidal_positional_embedding(max_seq_len: int, d_model: int) -> torch.Tensor:
     pe = torch.zeros(max_seq_len, d_model)
     position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
     div_term = torch.exp(
@@ -149,7 +149,7 @@ class UNet(nn.Module):
         intermediates = []
         for enc, dow in zip(self.encoder, self.downsamplers):
             x = enc(x, t_emb)
-            intermediates.append(x)  # TODO: Is this a copy?
+            intermediates.append(x)
             x = dow(x)
 
         x = self.bottleneck(x, t_emb)

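The first hunk cuts off inside the embedding computation. For orientation, here is a minimal sketch of the standard transformer-style sinusoidal embedding that the visible lines point to; everything after `div_term = torch.exp(` is an assumption about the elided code, not part of the commit:

```python
import math
import torch

def sinusoidal_positional_embedding(max_seq_len: int, d_model: int) -> torch.Tensor:
    # One row per timestep, one column per embedding dimension (d_model assumed even)
    pe = torch.zeros(max_seq_len, d_model)
    position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
    # Geometric progression of frequencies, as in "Attention Is All You Need"
    div_term = torch.exp(
        torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
    )
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions: sine
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions: cosine
    return pe
```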
src/simple_ddpm/sample.py (51)

@@ -1,3 +1,4 @@
+from typing import Dict, Union
 import torch
 from tqdm import tqdm
 import matplotlib.pyplot as plt
@@ -6,12 +7,14 @@ from train import extract, get_diffusion_params
 from train import TIMESTEPS, IMAGE_SIZE, CHANNELS
 from model import UNet
 
-torch.manual_seed(42)
+torch.manual_seed(1)
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 @torch.no_grad()
-def p_sample(model, x, t, params):
+def p_sample(
+    model, x: torch.Tensor, t: torch.Tensor, params: Dict[str, torch.Tensor]
+) -> torch.Tensor:
     """Sample from the model at timestep t"""
     predicted_noise = model(x, t)
@@ -30,7 +33,14 @@ def p_sample(model, x, t, params):
 
 @torch.no_grad()
-def sample_images(model, image_size, batch_size, channels, device, params):
+def sample_images(
+    model: torch.nn.Module,
+    image_size: int,
+    batch_size: int,
+    channels: int,
+    device: torch.device,
+    params: Dict[str, torch.Tensor],
+):
     """Generate new images using the trained model"""
     x = torch.randn(batch_size, channels, image_size, image_size).to(device)
@@ -46,7 +56,7 @@ def sample_images(model, image_size, batch_size, channels, device, params):
     return x
 
-def show_images(images, title=""):
+def show_images(images: Union[torch.Tensor, np.ndarray], title=""):
     """Display a batch of images in a grid"""
     if isinstance(images, torch.Tensor):
         images = images.detach().cpu().numpy()
@@ -61,21 +71,24 @@ def show_images(images, title=""):
     plt.pause(0.001)
 
-plt.figure(figsize=(10, 10))
-params = get_diffusion_params(TIMESTEPS, device)
-model = UNet(32, TIMESTEPS).to(device)
-model.load_state_dict(torch.load("model.pkl", weights_only=True))
-model.eval()
-generated_images = sample_images(
-    model=model,
-    image_size=IMAGE_SIZE,
-    batch_size=16,
-    channels=CHANNELS,
-    device=device,
-    params=params,
-)
-show_images(generated_images, title="Generated Images")
-plt.show()
+if __name__ == "__main__":
+    plt.figure(figsize=(10, 10))
+    params = get_diffusion_params(TIMESTEPS, device)
+
+    model = UNet(32, TIMESTEPS).to(device)
+    model.load_state_dict(torch.load("model.pkl", weights_only=True))
+    model.eval()
+
+    generated_images = sample_images(
+        model=model,
+        image_size=IMAGE_SIZE,
+        batch_size=16,
+        channels=CHANNELS,
+        device=device,
+        params=params,
+    )
+    show_images(generated_images, title="Generated Images")
+    # Keep the plot open after generation is finished
+    plt.show()
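Both sampling functions are truncated by the hunk context. As a reading aid, here is a sketch of the ancestral DDPM sampling step (Algorithm 2 of the paper) consistent with the signatures above; the `params` keys and the elided bodies are assumptions, not taken from the commit:

```python
import torch
from train import extract, TIMESTEPS  # mirrors sample.py's own imports

@torch.no_grad()
def p_sample(model, x, t, params):
    """One reverse-diffusion step: draw x_{t-1} given x_t."""
    predicted_noise = model(x, t)
    # Per-timestep coefficients, broadcast to the image shape (key names assumed)
    betas_t = extract(params["betas"], t, x.shape)
    sqrt_one_minus_ac_t = extract(params["sqrt_one_minus_alphas_cumprod"], t, x.shape)
    sqrt_recip_alphas_t = extract(params["sqrt_recip_alphas"], t, x.shape)

    # Posterior mean (DDPM Eq. 11): remove the predicted noise from x_t
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * predicted_noise / sqrt_one_minus_ac_t
    )
    if t[0] == 0:
        return model_mean  # final step: no noise is added
    posterior_variance_t = extract(params["posterior_variance"], t, x.shape)
    noise = torch.randn_like(x)
    return model_mean + torch.sqrt(posterior_variance_t) * noise

@torch.no_grad()
def sample_images(model, image_size, batch_size, channels, device, params):
    """Start from pure noise and denoise for TIMESTEPS steps."""
    x = torch.randn(batch_size, channels, image_size, image_size).to(device)
    for i in reversed(range(TIMESTEPS)):
        t = torch.full((batch_size,), i, device=device, dtype=torch.long)
        x = p_sample(model, x, t, params)
    return x
```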

src/simple_ddpm/train.py (29)

@@ -1,3 +1,4 @@
+from typing import Dict, Callable
 import torch
 from torchvision import datasets, transforms
 from torch.utils.data import DataLoader
@@ -20,7 +21,6 @@ def count_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
 
-# Data loading and preprocessing
 transform = transforms.Compose(
     [
         transforms.Resize(IMAGE_SIZE),
@@ -30,7 +30,6 @@ transform = transforms.Compose(
     ]
 )
 
-# Download and load CIFAR-10
 train_dataset = datasets.CIFAR10(
     root="./data", train=True, download=True, transform=transform
 )
@@ -40,13 +39,14 @@ train_loader = DataLoader(
 )
 
 def linear_beta_schedule(timesteps):
     beta_start = 0.0001
     beta_end = 0.02
     return torch.linspace(beta_start, beta_end, timesteps)
 
-def get_diffusion_params(timesteps, device):
+def get_diffusion_params(
+    timesteps: int, device: torch.device
+) -> Dict[str, torch.Tensor]:
     betas = linear_beta_schedule(timesteps)
     alphas = 1.0 - betas
     alphas_cumprod = torch.cumprod(alphas, dim=0)
@@ -68,14 +68,14 @@ def get_diffusion_params(timesteps, device):
     }
 
-def extract(a, t, x_shape):
+def extract(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size):
     """Extract coefficients at specified timesteps t"""
     batch_size = t.shape[0]
     out = a.gather(-1, t)
     return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
 
-def get_loss_fn(model, params):
+def get_loss_fn(model: torch.nn.Module, params: Dict[str, torch.Tensor]) -> Callable:
     def loss_fn(x_0):
         batch_size = x_0.shape[0]
         t = torch.randint(0, TIMESTEPS, (batch_size,), device=device)
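The dict returned by `get_diffusion_params` and the body of `loss_fn` are both elided by the hunk context. Here is a sketch of the usual DDPM precomputation and noise-prediction loss, with key names matching the sampling sketch above; all of it is assumed, not taken from the commit (it relies on `linear_beta_schedule`, `extract`, `TIMESTEPS`, and `device` from train.py's own scope):

```python
import torch
import torch.nn.functional as F

def get_diffusion_params(timesteps, device):
    betas = linear_beta_schedule(timesteps)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # Shifted so that index t holds the cumulative product up to t-1
    alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
    return {
        "betas": betas.to(device),
        # Forward process q(x_t | x_0) coefficients
        "sqrt_alphas_cumprod": torch.sqrt(alphas_cumprod).to(device),
        "sqrt_one_minus_alphas_cumprod": torch.sqrt(1.0 - alphas_cumprod).to(device),
        # Reverse (sampling) step coefficients
        "sqrt_recip_alphas": torch.sqrt(1.0 / alphas).to(device),
        "posterior_variance": (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        ).to(device),
    }

def get_loss_fn(model, params):
    def loss_fn(x_0):
        batch_size = x_0.shape[0]
        # One random timestep per example
        t = torch.randint(0, TIMESTEPS, (batch_size,), device=device)
        noise = torch.randn_like(x_0)
        # q_sample: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise
        x_t = (
            extract(params["sqrt_alphas_cumprod"], t, x_0.shape) * x_0
            + extract(params["sqrt_one_minus_alphas_cumprod"], t, x_0.shape) * noise
        )
        # Simple DDPM objective: predict the injected noise
        return F.mse_loss(model(x_t, t), noise)
    return loss_fn
```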
@@ -95,7 +95,9 @@ def get_loss_fn(model, params):
 
 # Training loop template
-def train_epoch(model, optimizer, train_loader, loss_fn):
+def train_epoch(
+    model: torch.nn.Module, optimizer: torch.optim.Optimizer, train_loader: DataLoader, loss_fn: Callable
+) -> float:
     model.train()
     total_loss = 0
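The hunk stops after `total_loss = 0`. A plausible completion of the epoch loop, assuming the loader yields `(images, labels)` pairs as the CIFAR-10 dataset does; this is a sketch, not the commit's actual body:

```python
def train_epoch(model, optimizer, train_loader, loss_fn):
    model.train()
    total_loss = 0
    for images, _ in train_loader:  # class labels are unused
        images = images.to(device)
        optimizer.zero_grad()
        loss = loss_fn(images)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Mean loss over the epoch
    return total_loss / len(train_loader)
```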
@@ -118,9 +120,14 @@ if __name__ == "__main__":
     model = UNet(32, TIMESTEPS).to(device)
     nb_params = count_parameters(model)
     print(f"Total number of parameters: {nb_params}")
+
     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95))
     params = get_diffusion_params(TIMESTEPS, device)
     loss_fn = get_loss_fn(model, params)
 
+    # Main training loop
     for e in tqdm(range(NUM_EPOCHS)):
         train_epoch(model, optimizer, train_loader, loss_fn)
+
+    # Save model after training
+    torch.save(model.state_dict(), "model.pkl")
