commit 7e3247eefc
4 changed files with 351 additions and 0 deletions
@@ -0,0 +1,18 @@ pyproject.toml
[project]
name = "simple-ddpm"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "torch>=2.5.0",
    "torchvision>=0.20.0",
    "matplotlib>=3.9.2",
    "tqdm>=4.66.5",
    "numpy>=2.1.2",
    "pyqt6>=6.7.1",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

@@ -0,0 +1,125 @@ model.py
import math

import torch
import torch.nn as nn

T_EMBEDDING_SIZE = 32


def sinusoidal_positional_embedding(max_seq_len, d_model):
    """Transformer-style sinusoidal embedding table of shape (max_seq_len, d_model)."""
    pe = torch.zeros(max_seq_len, d_model)
    position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
    )

    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)

    return pe
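
# For intuition: row t of the table is the embedding for timestep t. A minimal
# check (the values follow directly from sin(0) = 0 and cos(0) = 1):
#
#   pe = sinusoidal_positional_embedding(1000, T_EMBEDDING_SIZE)
#   pe.shape            # torch.Size([1000, 32])
#   pe[0, 0::2].sum()   # 0.0  -- the 16 sine channels at t=0
#   pe[0, 1::2].sum()   # 16.0 -- the 16 cosine channels at t=0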


class Conv(nn.Module):
    """Two 3x3 conv + BatchNorm blocks (second activation left disabled), plus dropout."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=out_channels),
            nn.SiLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_features=out_channels),
            # nn.SiLU(),
            nn.Dropout(0.1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class Downsample(nn.Module):
    def __init__(self, channels: int):
        super().__init__()

        # Strided conv halves the spatial resolution
        self.downsample = nn.Conv2d(
            channels, channels, kernel_size=3, padding=1, stride=2
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.downsample(x)


class Upsample(nn.Module):
    def __init__(self, channels: int):
        super().__init__()

        # Transposed conv doubles the spatial resolution
        self.upsample = nn.ConvTranspose2d(
            channels, channels, kernel_size=2, padding=0, stride=2
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.upsample(x)


class UNet(nn.Module):
    def __init__(
        self,
        image_size: int,  # currently unused; kept for the callers' signature
        nb_timesteps: int,
        in_channels: int = 3,
        out_channels: int = 3,
    ):
        super().__init__()

        self.register_buffer(
            "position_embeddings",
            sinusoidal_positional_embedding(nb_timesteps, T_EMBEDDING_SIZE),
        )

        self.encoder = nn.ModuleList(
            [
                Conv(in_channels + T_EMBEDDING_SIZE, 64),
                Conv(64, 128),
                Conv(128, 256),
                Conv(256, 512),
            ]
        )
        self.downsamplers = nn.ModuleList(
            [Downsample(64), Downsample(128), Downsample(256), Downsample(512)]
        )
        self.bottleneck = Conv(512, 512)

        self.decoder = nn.ModuleList(
            [
                Conv(2 * 512, 256),
                Conv(2 * 256, 128),
                Conv(2 * 128, 64),
                Conv(2 * 64, out_channels),
            ]
        )
        self.upsamplers = nn.ModuleList(
            [Upsample(512), Upsample(256), Upsample(128), Upsample(64)]
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        B, H, W = x.shape[0], x.shape[2], x.shape[3]
        t_embedding = self.position_embeddings[t]  # (B, T_EMBEDDING_SIZE)
        t_embedding = t_embedding.view(B, T_EMBEDDING_SIZE, 1, 1).expand(
            B, T_EMBEDDING_SIZE, H, W
        )
        x = torch.cat((t_embedding, x), dim=1)

        intermediates = []
        for enc, dow in zip(self.encoder, self.downsamplers):
            x = enc(x)
            intermediates.append(x)  # stores a reference, not a copy; safe since x is rebound below
            x = dow(x)

        x = self.bottleneck(x)

        for dec, up, m in zip(self.decoder, self.upsamplers, reversed(intermediates)):
            x = up(x)
            x = torch.cat((x, m), dim=1)  # concatenate along the channel dimension
            x = dec(x)

        return x
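

if __name__ == "__main__":
    # Smoke test (a minimal sketch, assuming the 32x32 RGB / 1000-step setup
    # used by train.py): the encoder downsamples 32 -> 16 -> 8 -> 4 -> 2 and the
    # decoder mirrors it, so the output must match the input spatial shape.
    net = UNet(image_size=32, nb_timesteps=1000)
    x = torch.randn(2, 3, 32, 32)
    t = torch.randint(0, 1000, (2,))
    assert net(x, t).shape == (2, 3, 32, 32)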

@@ -0,0 +1,80 @@
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

from train import extract, get_diffusion_params
from train import TIMESTEPS, IMAGE_SIZE, CHANNELS
from model import UNet

torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@torch.no_grad()
def p_sample(model, x, t, params):
    """Sample x_{t-1} from the model's reverse distribution at timestep t."""
    predicted_noise = model(x, t)

    one_over_alphas = extract(params["one_over_alphas"], t, x.shape)
    posterior_mean_coef = extract(params["posterior_mean_coef"], t, x.shape)

    pred_mean = one_over_alphas * (x - posterior_mean_coef * predicted_noise)

    posterior_variance = extract(params["posterior_variance"], t, x.shape)

    if t[0] > 0:
        noise = torch.randn_like(x)
        return pred_mean + torch.sqrt(posterior_variance) * noise
    else:
        # No noise is added at the final step (t == 0)
        return pred_mean
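
# This is the DDPM ancestral sampling step (Ho et al., 2020, Alg. 2):
#   x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alphabar_t) * eps_theta(x_t, t)) + sigma_t * z
# with z ~ N(0, I) for t > 0, and sigma_t^2 the posterior variance computed in
# get_diffusion_params in train.py ("one_over_alphas" holds 1/sqrt(alpha_t)).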


@torch.no_grad()
def sample_images(model, image_size, batch_size, channels, device, params):
    """Generate new images using the trained model"""
    model.eval()

    # Start from pure noise
    x = torch.randn(batch_size, channels, image_size, image_size).to(device)

    # Gradually denoise the image
    for t in tqdm(reversed(range(TIMESTEPS)), desc="Sampling", total=TIMESTEPS):
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
        x = p_sample(model, x, t_batch, params)

        if x.isnan().any():
            raise ValueError(f"NaN detected in image at timestep {t}")

    return x


def show_images(images, title=""):
    """Display a batch of images in a grid"""
    if isinstance(images, torch.Tensor):
        images = images.detach().cpu().numpy()

    # Approximate inverse of the CIFAR-10 normalization used in train.py
    images = (images * 0.25 + 0.5).clip(0, 1)
    plt.figure(figsize=(10, 10))
    for idx in range(min(16, len(images))):
        plt.subplot(4, 4, idx + 1)
        plt.imshow(np.transpose(images[idx], (1, 2, 0)))
        plt.axis("off")
    plt.suptitle(title)
    plt.show()


params = get_diffusion_params(TIMESTEPS, device)

model = UNet(IMAGE_SIZE, TIMESTEPS).to(device)
model.load_state_dict(torch.load("model.pkl", weights_only=True))

model.eval()
generated_images = sample_images(
    model=model,
    image_size=IMAGE_SIZE,
    batch_size=16,
    channels=CHANNELS,
    device=device,
    params=params,
)
show_images(generated_images, title="Generated Images")
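
# If no display backend is available, torchvision can write the same grid to
# disk instead (a sketch; "samples.png" is an arbitrary name):
#
#   from torchvision.utils import save_image
#   save_image(generated_images * 0.25 + 0.5, "samples.png", nrow=4)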

@@ -0,0 +1,128 @@ train.py
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

from model import UNet

# Set random seed for reproducibility
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
NUM_EPOCHS = 128
BATCH_SIZE = 128
IMAGE_SIZE = 32
CHANNELS = 3
TIMESTEPS = 1000

# Data loading and preprocessing
transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ]
)

# Download and load CIFAR-10. Note: this runs at import time, so importing
# from this module (as the sampling script does) also triggers the download.
train_dataset = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True
)


# Define beta schedule (linear schedule as an example)
def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)


# Calculate diffusion parameters
def get_diffusion_params(timesteps, device):
    betas = linear_beta_schedule(timesteps)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

    # Calculate posterior mean coefficients
    # (note: "one_over_alphas" actually holds 1/sqrt(alpha_t))
    one_over_alphas = 1.0 / torch.sqrt(alphas)
    posterior_mean_coef = betas / sqrt_one_minus_alphas_cumprod

    # Posterior variance
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

    return {
        "betas": betas.to(device),
        "alphas_cumprod": alphas_cumprod.to(device),
        "posterior_variance": posterior_variance.to(device),
        "one_over_alphas": one_over_alphas.to(device),
        "posterior_mean_coef": posterior_mean_coef.to(device),
    }
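
# A quick sanity check of the schedule (values are approximate): with the
# linear schedule above, alphas_cumprod decays from ~0.9999 to ~4e-5, so the
# forward process ends near pure noise, matching sampling's randn start.
#
#   p = get_diffusion_params(1000, torch.device("cpu"))
#   p["alphas_cumprod"][0]    # ~0.9999
#   p["alphas_cumprod"][-1]   # ~4.0e-5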


# Utility functions
def extract(a, t, x_shape):
    """Extract coefficients at timesteps t, reshaped to broadcast against x_shape."""
    batch_size = t.shape[0]
    out = a.gather(-1, t)
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


# Training utilities
def get_loss_fn(model, params):
    def loss_fn(x_0):
        batch_size = x_0.shape[0]
        t = torch.randint(0, TIMESTEPS, (batch_size,), device=device)
        noise = torch.randn_like(x_0)

        # Get the noisy image
        alpha_cumprod = extract(params["alphas_cumprod"], t, x_0.shape)
        noise_level = torch.sqrt(1.0 - alpha_cumprod)
        x_noisy = torch.sqrt(alpha_cumprod) * x_0 + noise_level * noise

        # Get predicted noise
        predicted_noise = model(x_noisy, t)

        return torch.nn.functional.mse_loss(predicted_noise, noise)

    return loss_fn
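
# The closure implements the simplified DDPM objective (L_simple in Ho et al.,
# 2020): sample t uniformly, form x_t via the closed-form marginal
#   q(x_t | x_0) = N(sqrt(alphabar_t) * x_0, (1 - alphabar_t) * I),
# i.e. x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps,
# and regress the model output onto the noise eps with an MSE loss.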


# Training loop template
def train_epoch(model, optimizer, train_loader, loss_fn, ema_model=None):
    model.train()
    total_loss = 0

    with tqdm(train_loader, leave=False) as pbar:
        for batch in pbar:
            images = batch[0].to(device)
            optimizer.zero_grad()

            loss = loss_fn(images)
            loss.backward()
            optimizer.step()

            # Keep the EMA weights in sync with the optimizer steps
            # (the original never updated model_avg, leaving the EMA unused)
            if ema_model is not None:
                ema_model.update_parameters(model)

            total_loss += loss.item()
            pbar.set_description(f"Loss: {loss.item():.4f}")

    return total_loss / len(train_loader)


if __name__ == "__main__":
    model = UNet(IMAGE_SIZE, TIMESTEPS).to(device)
    model_avg = torch.optim.swa_utils.AveragedModel(
        model, multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.9999)
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95))
    params = get_diffusion_params(TIMESTEPS, device)
    loss_fn = get_loss_fn(model, params)
    for e in tqdm(range(NUM_EPOCHS)):
        train_epoch(model, optimizer, train_loader, loss_fn, ema_model=model_avg)
    torch.save(model.state_dict(), "model.pkl")
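    # The EMA weights typically sample better than the raw weights; saving
    # them too is a one-liner (a sketch -- "model_ema.pkl" is an arbitrary
    # name, and the sampling script would need to load it explicitly):
    # torch.save(model_avg.module.state_dict(), "model_ema.pkl")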