First commit

Ramon Calvo committed to master, 1 year ago · commit 7e3247eefc

Changed files:
  pyproject.toml             (+18)
  src/simple_ddpm/model.py   (+125)
  src/simple_ddpm/sample.py  (+80)
  src/simple_ddpm/train.py   (+128)

pyproject.toml (+18)

@@ -0,0 +1,18 @@
[project]
name = "simple-ddpm"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "torch>=2.5.0",
    "torchvision>=0.20.0",
    "matplotlib>=3.9.2",
    "tqdm>=4.66.5",
    "numpy>=2.1.2",
    "pyqt6>=6.7.1",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
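
With the hatchling build backend declared above, this should install as an ordinary PEP 517 package: `pip install -e .` from the repository root is expected to pull in the pinned dependencies (or `uv sync`, if the project is managed with uv — an assumption, since no lockfile appears in this commit).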

src/simple_ddpm/model.py (+125)

@@ -0,0 +1,125 @@
import torch
import torch.nn as nn
import math

T_EMBEDDING_SIZE = 32


def sinusoidal_positional_embedding(max_seq_len, d_model):
    # Standard transformer-style sinusoidal table: sin on even indices,
    # cos on odd indices, with geometrically spaced frequencies.
    pe = torch.zeros(max_seq_len, d_model)
    position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
    )
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe


class Conv(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=out_channels),
            # nn.SiLU(),
            nn.Dropout(0.1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class Downsample(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        # Strided conv halves the spatial resolution.
        self.downsample = nn.Conv2d(
            channels, channels, kernel_size=3, padding=1, stride=2
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.downsample(x)


class Upsample(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        # Transposed conv doubles the spatial resolution.
        self.upsample = nn.ConvTranspose2d(
            channels, channels, kernel_size=2, padding=0, stride=2
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.upsample(x)


class UNet(nn.Module):
    def __init__(
        self,
        image_size: int,
        nb_timesteps: int,
        in_channels: int = 3,
        out_channels: int = 3,
    ):
        super().__init__()
        # Precomputed sinusoidal timestep embeddings, registered as a buffer
        # so they move with the module across devices.
        self.register_buffer(
            "position_embeddings",
            sinusoidal_positional_embedding(nb_timesteps, T_EMBEDDING_SIZE),
        )
        self.encoder = nn.ModuleList(
            [
                Conv(in_channels + T_EMBEDDING_SIZE, 64),
                Conv(64, 128),
                Conv(128, 256),
                Conv(256, 512),
            ]
        )
        self.downsamplers = nn.ModuleList(
            [Downsample(64), Downsample(128), Downsample(256), Downsample(512)]
        )
        self.bottleneck = Conv(512, 512)
        # Each decoder stage sees the upsampled features concatenated with the
        # matching encoder skip connection, hence the doubled input channels.
        self.decoder = nn.ModuleList(
            [
                Conv(2 * 512, 256),
                Conv(2 * 256, 128),
                Conv(2 * 128, 64),
                Conv(2 * 64, out_channels),
            ]
        )
        self.upsamplers = nn.ModuleList(
            [Upsample(512), Upsample(256), Upsample(128), Upsample(64)]
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        B, H, W = x.shape[0], x.shape[2], x.shape[3]
        # Broadcast the timestep embedding over the spatial dimensions and
        # concatenate it to the input as extra channels.
        t_embedding = self.position_embeddings[t]  # (B, T_EMBEDDING_SIZE)
        t_embedding = t_embedding.view(B, T_EMBEDDING_SIZE, 1, 1).expand(
            B, T_EMBEDDING_SIZE, H, W
        )
        x = torch.cat((t_embedding, x), dim=1)
        intermediates = []
        for enc, dow in zip(self.encoder, self.downsamplers):
            x = enc(x)
            intermediates.append(x)  # Stores a reference; safe because x is rebound, not mutated
            x = dow(x)
        x = self.bottleneck(x)
        for dec, up, m in zip(self.decoder, self.upsamplers, reversed(intermediates)):
            x = up(x)
            x = torch.cat((x, m), dim=1)  # Concatenate along the channel dimension
            x = dec(x)
        return x
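
As a quick sanity check of the interface above, a minimal sketch (not part of the commit; the sizes mirror IMAGE_SIZE and TIMESTEPS from train.py):

import torch
from model import UNet

model = UNet(image_size=32, nb_timesteps=1000)
x = torch.randn(4, 3, 32, 32)                # batch of CIFAR-sized images
t = torch.randint(0, 1000, (4,))             # one timestep index per image
assert model(x, t).shape == (4, 3, 32, 32)   # predicted noise matches input shape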

src/simple_ddpm/sample.py (+80)

@@ -0,0 +1,80 @@
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# NOTE: importing train runs its module-level code, including the CIFAR-10 download.
from train import extract, get_diffusion_params
from train import TIMESTEPS, IMAGE_SIZE, CHANNELS
from model import UNet

torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@torch.no_grad()
def p_sample(model, x, t, params):
    """Sample x_{t-1} from the model at timestep t."""
    predicted_noise = model(x, t)
    one_over_alphas = extract(params["one_over_alphas"], t, x.shape)
    posterior_mean_coef = extract(params["posterior_mean_coef"], t, x.shape)
    pred_mean = one_over_alphas * (x - posterior_mean_coef * predicted_noise)
    posterior_variance = extract(params["posterior_variance"], t, x.shape)
    if t[0] > 0:  # t is constant across the batch during sampling
        # Add noise scaled by the posterior standard deviation.
        noise = torch.randn_like(x)
        return pred_mean + torch.sqrt(posterior_variance) * noise
    else:
        # Final step: return the mean without added noise.
        return pred_mean


@torch.no_grad()
def sample_images(model, image_size, batch_size, channels, device, params):
    """Generate new images using the trained model."""
    model.eval()
    # Start from pure noise
    x = torch.randn(batch_size, channels, image_size, image_size).to(device)
    # Gradually denoise the image
    for t in tqdm(reversed(range(TIMESTEPS)), desc="Sampling", total=TIMESTEPS):
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
        x = p_sample(model, x, t_batch, params)
        if x.isnan().any():
            raise ValueError(f"NaN detected in image at timestep {t}")
    return x


def show_images(images, title=""):
    """Display a batch of images in a grid."""
    if isinstance(images, torch.Tensor):
        images = images.detach().cpu().numpy()
    # Roughly undo the CIFAR-10 normalization from train.py
    # (per-channel mean ~0.5, std ~0.25).
    images = (images * 0.25 + 0.5).clip(0, 1)
    plt.figure(figsize=(10, 10))
    for idx in range(min(16, len(images))):
        plt.subplot(4, 4, idx + 1)
        plt.imshow(np.transpose(images[idx], (1, 2, 0)))
        plt.axis("off")
    plt.suptitle(title)
    plt.show()


params = get_diffusion_params(TIMESTEPS, device)
model = UNet(32, TIMESTEPS).to(device)
model.load_state_dict(torch.load("model.pkl", weights_only=True))
model.eval()
generated_images = sample_images(
    model=model,
    image_size=IMAGE_SIZE,
    batch_size=16,
    channels=CHANNELS,
    device=device,
    params=params,
)
show_images(generated_images, title="Generated Images")
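
For reference, p_sample above is the standard DDPM ancestral sampling step; writing \alpha_t = 1 - \beta_t and \bar{\alpha}_t for the cumulative product, the update matches the coefficients built in get_diffusion_params:

x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \, \epsilon_\theta(x_t, t) \Big) + \sqrt{\tilde{\beta}_t} \, z,
\qquad \tilde{\beta}_t = \frac{\beta_t (1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}, \quad z \sim \mathcal{N}(0, I),

with the noise term z dropped at t = 0 (the `else` branch).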

src/simple_ddpm/train.py (+128)

@@ -0,0 +1,128 @@
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from model import UNet

# Set random seed for reproducibility
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
NUM_EPOCHS = 128
BATCH_SIZE = 128
IMAGE_SIZE = 32
CHANNELS = 3
TIMESTEPS = 1000

# Data loading and preprocessing
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(),
        # Per-channel CIFAR-10 mean and std
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ]
)

# Download and load CIFAR-10
train_dataset = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True
)


# Define beta schedule (linear schedule as an example)
def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)


# Calculate diffusion parameters
def get_diffusion_params(timesteps, device):
    betas = linear_beta_schedule(timesteps)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
    # Calculate posterior mean coefficients
    one_over_alphas = 1.0 / torch.sqrt(alphas)  # despite the name, 1/sqrt(alpha_t)
    posterior_mean_coef = betas / sqrt_one_minus_alphas_cumprod
    # Posterior variance
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    return {
        "betas": betas.to(device),
        "alphas_cumprod": alphas_cumprod.to(device),
        "posterior_variance": posterior_variance.to(device),
        "one_over_alphas": one_over_alphas.to(device),
        "posterior_mean_coef": posterior_mean_coef.to(device),
    }


# Utility functions
def extract(a, t, x_shape):
    """Extract coefficients at specified timesteps t."""
    batch_size = t.shape[0]
    out = a.gather(-1, t)
    # Reshape to (batch_size, 1, 1, ...) so the result broadcasts against x
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


# Training utilities
def get_loss_fn(model, params):
    def loss_fn(x_0):
        batch_size = x_0.shape[0]
        t = torch.randint(0, TIMESTEPS, (batch_size,), device=device)
        noise = torch.randn_like(x_0)
        # Get the noisy image: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        alpha_cumprod = extract(params["alphas_cumprod"], t, x_0.shape)
        noise_level = torch.sqrt(1.0 - alpha_cumprod)
        x_noisy = torch.sqrt(alpha_cumprod) * x_0 + noise_level * noise
        # Get predicted noise
        predicted_noise = model(x_noisy, t)
        return torch.nn.functional.mse_loss(predicted_noise, noise)

    return loss_fn


# Training loop template
def train_epoch(model, optimizer, train_loader, loss_fn):
    model.train()
    total_loss = 0
    with tqdm(train_loader, leave=False) as pbar:
        for batch in pbar:
            images = batch[0].to(device)
            optimizer.zero_grad()
            loss = loss_fn(images)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            pbar.set_description(f"Loss: {loss.item():.4f}")
    return total_loss / len(train_loader)


if __name__ == "__main__":
    model = UNet(32, TIMESTEPS).to(device)
    # EMA copy of the weights; not yet updated or saved in this commit.
    model_avg = torch.optim.swa_utils.AveragedModel(
        model, multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.9999)
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95))
    params = get_diffusion_params(TIMESTEPS, device)
    loss_fn = get_loss_fn(model, params)
    for e in tqdm(range(NUM_EPOCHS)):
        train_epoch(model, optimizer, train_loader, loss_fn)
    torch.save(model.state_dict(), "model.pkl")
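
The loss above is the standard simplified DDPM objective: loss_fn noises a clean batch in closed form and regresses the model's prediction against the true noise,

x_t = \sqrt{\bar{\alpha}_t} \, x_0 + \sqrt{1 - \bar{\alpha}_t} \, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I),
\qquad L_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon} \big[ \lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2 \big],

which is exactly what the x_noisy construction and the mse_loss call implement.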