commit 7e3247eefcb06abfe14d426d4d22c1edd167ce2e
Author: Ramon Calvo
Date:   Mon Oct 28 18:13:16 2024 +0100

    First commit

diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..b61f113
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,18 @@
+[project]
+name = "simple-ddpm"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+    "torch>=2.5.0",
+    "torchvision>=0.20.0",
+    "matplotlib>=3.9.2",
+    "tqdm>=4.66.5",
+    "numpy>=2.1.2",
+    "pyqt6>=6.7.1",
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
diff --git a/src/simple_ddpm/model.py b/src/simple_ddpm/model.py
new file mode 100644
index 0000000..6a722ec
--- /dev/null
+++ b/src/simple_ddpm/model.py
@@ -0,0 +1,125 @@
+import torch
+import torch.nn as nn
+import math
+
+
+T_EMBEDDING_SIZE = 32
+
+
+def sinusoidal_positional_embedding(max_seq_len, d_model):
+    pe = torch.zeros(max_seq_len, d_model)
+    position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
+    div_term = torch.exp(
+        torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
+    )
+
+    pe[:, 0::2] = torch.sin(position * div_term)
+    pe[:, 1::2] = torch.cos(position * div_term)
+
+    return pe
+
+
+class Conv(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int):
+        super(Conv, self).__init__()
+
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(num_features=out_channels),
+            nn.SiLU(),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(num_features=out_channels),
+            # nn.SiLU(),
+            nn.Dropout(0.1),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.conv(x)
+
+
+class Downsample(nn.Module):
+    def __init__(self, channels: int):
+        super(Downsample, self).__init__()
+
+        self.downsample = nn.Conv2d(
+            channels, channels, kernel_size=3, padding=1, stride=2
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.downsample(x)
+
+
+class Upsample(nn.Module):
+    def __init__(self, channels: int):
+        super(Upsample, self).__init__()
+
+        self.upsample = nn.ConvTranspose2d(
+            channels, channels, kernel_size=2, padding=0, stride=2
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.upsample(x)
+
+
+class UNet(nn.Module):
+    def __init__(
+        self,
+        image_size: int,
+        nb_timesteps: int,
+        in_channels: int = 3,
+        out_channels: int = 3,
+    ):
+        super(UNet, self).__init__()
+
+        self.register_buffer(
+            "position_embeddings",
+            sinusoidal_positional_embedding(nb_timesteps, T_EMBEDDING_SIZE),
+        )
+
+        self.encoder = nn.ModuleList(
+            [
+                Conv(in_channels + T_EMBEDDING_SIZE, 64),
+                Conv(64, 128),
+                Conv(128, 256),
+                Conv(256, 512),
+            ]
+        )
+        self.downsamplers = nn.ModuleList(
+            [Downsample(64), Downsample(128), Downsample(256), Downsample(512)]
+        )
+        self.bottleneck = Conv(512, 512)
+
+        self.decoder = nn.ModuleList(
+            [
+                Conv(2 * 512, 256),
+                Conv(2 * 256, 128),
+                Conv(2 * 128, 64),
+                Conv(2 * 64, out_channels),
+            ]
+        )
+        self.upsamplers = nn.ModuleList(
+            [Upsample(512), Upsample(256), Upsample(128), Upsample(64)]
+        )
+
+    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        B, H, W = x.shape[0], x.shape[2], x.shape[3]
+        t_embedding = self.position_embeddings[t]  # (B, T_DIM)
+        t_embedding = t_embedding.view(B, T_EMBEDDING_SIZE, 1, 1).expand(
+            B, T_EMBEDDING_SIZE, H, W
+        )
+        x = torch.cat((t_embedding, x), axis=1)
+
+        intermediates = []
+        for enc, dow in zip(self.encoder, self.downsamplers):
+            x = enc(x)
+            intermediates.append(x)  # TODO: Is this a copy?
+            x = dow(x)
+
+        x = self.bottleneck(x)
+
+        for dec, up, m in zip(self.decoder, self.upsamplers, reversed(intermediates)):
+            x = up(x)
+            x = torch.concat((x, m), axis=1)  # Channel dimension
+            x = dec(x)
+
+        return x
diff --git a/src/simple_ddpm/sample.py b/src/simple_ddpm/sample.py
new file mode 100644
index 0000000..a10e163
--- /dev/null
+++ b/src/simple_ddpm/sample.py
@@ -0,0 +1,80 @@
+import torch
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+import numpy as np
+from train import extract, get_diffusion_params
+from train import TIMESTEPS, IMAGE_SIZE, CHANNELS
+from model import UNet
+
+torch.manual_seed(42)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+@torch.no_grad()
+def p_sample(model, x, t, params):
+    """Sample from the model at timestep t"""
+    predicted_noise = model(x, t)
+
+    one_over_alphas = extract(params["one_over_alphas"], t, x.shape)
+    posterior_mean_coef = extract(params["posterior_mean_coef"], t, x.shape)
+
+    pred_mean = one_over_alphas * (x - posterior_mean_coef * predicted_noise)
+
+    posterior_variance = extract(params["posterior_variance"], t, x.shape)
+
+    if t[0] > 0:
+        noise = torch.randn_like(x)
+        return pred_mean + torch.sqrt(posterior_variance) * noise
+    else:
+        return pred_mean
+
+
+@torch.no_grad()
+def sample_images(model, image_size, batch_size, channels, device, params):
+    """Generate new images using the trained model"""
+    model.eval()
+
+    # Start from pure noise
+    x = torch.randn(batch_size, channels, image_size, image_size).to(device)
+
+    # Gradually denoise the image
+    for t in tqdm(reversed(range(TIMESTEPS)), desc="Sampling", total=TIMESTEPS):
+        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
+        x = p_sample(model, x, t_batch, params)
+
+        if x.isnan().any():
+            raise ValueError(f"NaN detected in image at timestep {t}")
+
+    return x
+
+
+def show_images(images, title=""):
+    """Display a batch of images in a grid"""
+    if isinstance(images, torch.Tensor):
+        images = images.detach().cpu().numpy()
+
+    images = (images * 0.25 + 0.5).clip(0, 1)
+    plt.figure(figsize=(10, 10))
+    for idx in range(min(16, len(images))):
+        plt.subplot(4, 4, idx + 1)
+        plt.imshow(np.transpose(images[idx], (1, 2, 0)))
+        plt.axis("off")
+    plt.suptitle(title)
+    plt.show()
+
+
+params = get_diffusion_params(TIMESTEPS, device)
+
+model = UNet(32, TIMESTEPS).to(device)
+model.load_state_dict(torch.load("model.pkl", weights_only=True))
+
+model.eval()
+generated_images = sample_images(
+    model=model,
+    image_size=IMAGE_SIZE,
+    batch_size=16,
+    channels=CHANNELS,
+    device=device,
+    params=params,
+)
+show_images(generated_images, title="Generated Images")
diff --git a/src/simple_ddpm/train.py b/src/simple_ddpm/train.py
new file mode 100644
index 0000000..c533156
--- /dev/null
+++ b/src/simple_ddpm/train.py
@@ -0,0 +1,128 @@
+import torch
+from torchvision import datasets, transforms
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from model import UNet
+
+# Set random seed for reproducibility
+torch.manual_seed(42)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Hyperparameters
+NUM_EPOCHS = 128
+BATCH_SIZE = 128
+IMAGE_SIZE = 32
+CHANNELS = 3
+TIMESTEPS = 1000
+
+# Data loading and preprocessing
+transform = transforms.Compose(
+    [
+        transforms.Resize(IMAGE_SIZE),
+        transforms.RandomHorizontalFlip(0.5),
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+    ]
+)
+
+# Download and load CIFAR-10
+train_dataset = datasets.CIFAR10(
+    root="./data", train=True, download=True, transform=transform
+)
+
+train_loader = DataLoader(
+    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True
+)
+
+
+# Define beta schedule (linear schedule as an example)
+def linear_beta_schedule(timesteps):
+    beta_start = 0.0001
+    beta_end = 0.02
+    return torch.linspace(beta_start, beta_end, timesteps)
+
+
+# Calculate diffusion parameters
+def get_diffusion_params(timesteps, device):
+    betas = linear_beta_schedule(timesteps)
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, axis=0)
+    alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
+
+    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
+
+    # Calculate posterior mean coefficients
+    one_over_alphas = 1.0 / torch.sqrt(alphas)
+    posterior_mean_coef = betas / sqrt_one_minus_alphas_cumprod
+
+    # Posterior variance
+    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
+
+    return {
+        "betas": betas.to(device),
+        "alphas_cumprod": alphas_cumprod.to(device),
+        "posterior_variance": posterior_variance.to(device),
+        "one_over_alphas": one_over_alphas.to(device),
+        "posterior_mean_coef": posterior_mean_coef.to(device),
+    }
+
+
+# Utility functions
+def extract(a, t, x_shape):
+    """Extract coefficients at specified timesteps t"""
+    batch_size = t.shape[0]
+    out = a.gather(-1, t)
+    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
+
+
+# Training utilities
+def get_loss_fn(model, params):
+    def loss_fn(x_0):
+        batch_size = x_0.shape[0]
+        t = torch.randint(0, TIMESTEPS, (batch_size,), device=device)
+        noise = torch.randn_like(x_0)
+
+        # Get the noisy image
+        alpha_cumprod = extract(params["alphas_cumprod"], t, x_0.shape)
+        noise_level = torch.sqrt(1.0 - alpha_cumprod)
+        x_noisy = torch.sqrt(alpha_cumprod) * x_0 + noise_level * noise
+
+        # Get predicted noise
+        predicted_noise = model(x_noisy, t)
+
+        return torch.nn.functional.mse_loss(predicted_noise, noise)
+
+    return loss_fn
+
+
+# Training loop template
+def train_epoch(model, optimizer, train_loader, loss_fn):
+    model.train()
+    total_loss = 0
+
+    with tqdm(train_loader, leave=False) as pbar:
+        for batch in pbar:
+            images = batch[0].to(device)
+            optimizer.zero_grad()
+
+            loss = loss_fn(images)
+            loss.backward()
+            optimizer.step()
+
+            total_loss += loss.item()
+            pbar.set_description(f"Loss: {loss.item():.4f}")
+
+    return total_loss / len(train_loader)
+
+
+if __name__ == "__main__":
+    model = UNet(32, TIMESTEPS).to(device)
+    model_avg = torch.optim.swa_utils.AveragedModel(
+        model, multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.9999)
+    )
+    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95))
+    params = get_diffusion_params(TIMESTEPS, device)
+    loss_fn = get_loss_fn(model, params)
+    for e in tqdm(range(NUM_EPOCHS)):
+        train_epoch(model, optimizer, train_loader, loss_fn)
+    torch.save(model.state_dict(), "model.pkl")
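For context, the loss in train.py and the reverse step in sample.py follow the standard DDPM recipe: a clean image x_0 is noised in closed form as x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, and the U-Net is trained to predict eps. The snippet below is a minimal, standalone sketch of that forward step under stated assumptions: it rebuilds the linear beta schedule rather than importing train.py (which constructs the CIFAR-10 dataset at module import time), and the batch shape is illustrative, not taken from the commit.

    import torch

    TIMESTEPS = 1000  # matches train.py

    # Linear beta schedule, as in linear_beta_schedule() in train.py.
    betas = torch.linspace(0.0001, 0.02, TIMESTEPS)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # alpha_bar_t

    x_0 = torch.randn(4, 3, 32, 32)        # stand-in for a batch of clean images
    t = torch.randint(0, TIMESTEPS, (4,))  # one random timestep per sample
    eps = torch.randn_like(x_0)            # the noise the U-Net learns to predict

    # Broadcast alpha_bar_t to the image shape, as extract() does in train.py.
    alpha_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    x_t = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1.0 - alpha_bar) * eps

    # Training minimizes MSE(model(x_t, t), eps); sample.py then inverts this
    # one step at a time using the posterior mean and variance.
    print(x_t.shape)  # torch.Size([4, 3, 32, 32])

To reproduce end to end, the scripts suggest running train.py first (it writes model.pkl to the working directory) and then sample.py from the same directory, since sample.py loads model.pkl with a relative path; both paths are assumptions about how the author invokes the scripts.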