diff --git a/README.md b/README.md
index ddca671..385b7ca 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# Tiny DDPM
+# Tiny DDPM (and DDIM)
 
 This is a bare bones and simple DDPM ([Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239)) implementation on PyTorch. The whole implementation (model + training + sampling) does not exceed 400 lines of code. The training setup and U-Net model loosely resemble the description of the original paper, but it is not a 1 to 1 implementation.
 
@@ -23,9 +23,13 @@ python -m pip install . # Otherwise
 uv run src/simple_ddpm/train.py # If using uv
 python src/simple_ddpm/train.py # Otherwise
 ```
-## Training
+## Sampling
 ```bash
 uv run src/simple_ddpm/sample.py # If using uv
 python src/simple_ddpm/sample.py # Otherwise
 ```
+
+By default, sampling uses DDPM. To use DDIM instead, change the function
+called at the bottom of `src/tiny_ddpm/sample.py` from `ddpm_sample_images`
+to `ddim_sample_images`.
diff --git a/src/simple_ddpm/sample.py b/src/simple_ddpm/sample.py
deleted file mode 100644
index 427118a..0000000
--- a/src/simple_ddpm/sample.py
+++ /dev/null
@@ -1,94 +0,0 @@
-from typing import Dict, Union
-import torch
-from tqdm import tqdm
-import matplotlib.pyplot as plt
-import numpy as np
-from train import extract, get_diffusion_params
-from train import TIMESTEPS, IMAGE_SIZE, CHANNELS
-from model import UNet
-
-torch.manual_seed(1)
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
-@torch.no_grad()
-def p_sample(
-    model, x: torch.Tensor, t: torch.Tensor, params: Dict[str, torch.Tensor]
-) -> torch.Tensor:
-    """Sample from the model at timestep t"""
-    predicted_noise = model(x, t)
-
-    one_over_alphas = extract(params["one_over_alphas"], t, x.shape)
-    posterior_mean_coef = extract(params["posterior_mean_coef"], t, x.shape)
-
-    pred_mean = one_over_alphas * (x - posterior_mean_coef * predicted_noise)
-
-    posterior_variance = extract(params["posterior_variance"], t, x.shape)
-
-    if t[0] > 0:
-        noise = torch.randn_like(x)
-        return pred_mean + torch.sqrt(posterior_variance) * noise
-    else:
-        return pred_mean
-
-
-@torch.no_grad()
-def sample_images(
-    model: torch.nn.Module,
-    image_size: int,
-    batch_size: int,
-    channels: int,
-    device: torch.device,
-    params: Dict[str, torch.Tensor],
-):
-    """Generate new images using the trained model"""
-    x = torch.randn(batch_size, channels, image_size, image_size).to(device)
-
-    for t in tqdm(reversed(range(TIMESTEPS)), desc="Sampling", total=TIMESTEPS):
-        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
-        x = p_sample(model, x, t_batch, params)
-        if t % 100 == 0:
-            show_images(x)
-
-        if x.isnan().any():
-            raise ValueError(f"NaN detected in image at timestep {t}")
-
-    return x
-
-
-def show_images(images: Union[torch.Tensor, np.array], title=""):
-    """Display a batch of images in a grid"""
-    if isinstance(images, torch.Tensor):
-        images = images.detach().cpu().numpy()
-
-    images = (images * 0.25 + 0.5).clip(0, 1)
-    for idx in range(min(16, len(images))):
-        plt.subplot(4, 4, idx + 1)
-        plt.imshow(np.transpose(images[idx], (1, 2, 0)))
-        plt.axis("off")
-    plt.suptitle(title)
-    plt.draw()
-    plt.pause(0.001)
-
-
-if __name__ == "__main__":
-    plt.figure(figsize=(10, 10))
-
-    params = get_diffusion_params(TIMESTEPS, device)
-
-    model = UNet(32, TIMESTEPS).to(device)
-    model.load_state_dict(torch.load("model.pkl", weights_only=True))
-
-    model.eval()
-    generated_images = sample_images(
-        model=model,
-        image_size=IMAGE_SIZE,
-        batch_size=16,
-        channels=CHANNELS,
-        device=device,
-        params=params,
-    )
-    show_images(generated_images, title="Generated Images")
-
-    # Keep the plot open after generation is finished
-    plt.show()
diff --git a/src/simple_ddpm/model.py b/src/tiny_ddpm/model.py
similarity index 100%
rename from src/simple_ddpm/model.py
rename to src/tiny_ddpm/model.py
diff --git a/src/tiny_ddpm/sample.py b/src/tiny_ddpm/sample.py
new file mode 100644
index 0000000..8528151
--- /dev/null
+++ b/src/tiny_ddpm/sample.py
@@ -0,0 +1,173 @@
+from typing import Dict, Union
+import torch
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+import numpy as np
+from train import extract, get_diffusion_params
+from train import TIMESTEPS, IMAGE_SIZE, CHANNELS, DDIM_TIMESTEPS
+from model import UNet
+
+torch.manual_seed(1)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+@torch.no_grad()
+def ddpm_sample(
+    model: torch.nn.Module,
+    x: torch.Tensor,
+    t: torch.Tensor,
+    params: Dict[str, torch.Tensor],
+) -> torch.Tensor:
+    """Sample from the model at timestep t"""
+    predicted_noise = model(x, t)
+
+    one_over_alphas = extract(params["one_over_alphas"], t, x.shape)
+    posterior_mean_coef = extract(params["posterior_mean_coef"], t, x.shape)
+
+    pred_mean = one_over_alphas * (x - posterior_mean_coef * predicted_noise)
+
+    posterior_variance = extract(params["posterior_variance"], t, x.shape)
+
+    if t[0] > 0:
+        noise = torch.randn_like(x)
+        return pred_mean + torch.sqrt(posterior_variance) * noise
+    else:
+        return pred_mean
+
+
+@torch.no_grad()
+def ddim_sample(
+    model: torch.nn.Module,
+    x: torch.Tensor,
+    t: torch.Tensor,
+    params: Dict[str, torch.Tensor],
+) -> torch.Tensor:
+    """Sample from the model in a non-Markovian way (DDIM)"""
+    stride = TIMESTEPS // DDIM_TIMESTEPS
+    t_prev = t - stride
+    predicted_noise = model(x, t)
+
+    alphas_prod = extract(params["alphas_cumprod"], t, x.shape)
+    valid_mask = (t_prev >= 0).view(-1, 1, 1, 1)
+    safe_t_prev = torch.maximum(t_prev, torch.tensor(0, device=device))
+    alphas_prod_prev = extract(params["alphas_cumprod"], safe_t_prev, x.shape)
+    alphas_prod_prev = torch.where(
+        valid_mask, alphas_prod_prev, torch.ones_like(alphas_prod_prev)
+    )
+
+    sigma = extract(params["ddim_sigma"], t, x.shape)
+
+    pred_x0 = (x - (1 - alphas_prod).sqrt() * predicted_noise) / alphas_prod.sqrt()
+
+    pred = (
+        alphas_prod_prev.sqrt() * pred_x0
+        + (1.0 - alphas_prod_prev).sqrt() * predicted_noise
+    )
+
+    if t[0] > 0:
+        noise = torch.randn_like(x)
+        pred = pred + noise * sigma
+
+    return pred
+
+
+@torch.no_grad()
+def ddpm_sample_images(
+    model: torch.nn.Module,
+    image_size: int,
+    batch_size: int,
+    channels: int,
+    device: torch.device,
+    params: Dict[str, torch.Tensor],
+):
+    """Generate new images using the trained model (DDPM sampling)"""
+    x = torch.randn(batch_size, channels, image_size, image_size).to(device)
+
+    for t in tqdm(reversed(range(TIMESTEPS)), desc="DDPM Sampling", total=TIMESTEPS):
+        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
+        x = ddpm_sample(model, x, t_batch, params)
+        if t % 100 == 0:
+            show_images(x)
+
+        if x.isnan().any():
+            raise ValueError(f"NaN detected in image at timestep {t}")
+
+    return x
+
+
+def get_ddim_timesteps(
+    total_timesteps: int, num_sampling_timesteps: int
+) -> torch.Tensor:
+    """Gets the timesteps used for the DDIM process."""
+    assert total_timesteps % num_sampling_timesteps == 0
+    stride = total_timesteps // num_sampling_timesteps
+    timesteps = torch.arange(0, total_timesteps, stride)
+    return timesteps.flip(0)
+
+
+@torch.no_grad()
+def ddim_sample_images(
+    model: torch.nn.Module,
+    image_size: int,
+    batch_size: int,
+    channels: int,
+    device: torch.device,
+    params: Dict[str, torch.Tensor],
+):
+    """Generate new images using the trained model (DDIM sampling)"""
+    x = torch.randn(batch_size, channels, image_size, image_size).to(device)
+
+    timesteps = get_ddim_timesteps(TIMESTEPS, DDIM_TIMESTEPS)
+
+    for i in tqdm(range(len(timesteps) - 1), desc="DDIM Sampling"):
+        t = torch.full((batch_size,), timesteps[i], device=device, dtype=torch.long)
+        x_before = x.clone()
+        x = ddim_sample(model, x, t, params)
+        print(f"Step {i}, max diff: {(x - x_before).abs().max().item()}")
+
+        if x.isnan().any():
+            raise ValueError(f"NaN detected at timestep {timesteps[i]}")
+
+        if i % 10 == 0:
+            show_images(x)
+    return x
+
+
+def show_images(images: Union[torch.Tensor, np.ndarray], title=""):
+    """Display a batch of images in a grid"""
+    if isinstance(images, torch.Tensor):
+        images = images.detach().cpu().numpy()
+
+    images = (images * 0.25 + 0.5).clip(0, 1)
+    for idx in range(min(16, len(images))):
+        plt.subplot(4, 4, idx + 1)
+        plt.imshow(np.transpose(images[idx], (1, 2, 0)))
+        plt.axis("off")
+    plt.suptitle(title)
+    plt.draw()
+    plt.pause(0.001)
+
+
+if __name__ == "__main__":
+    plt.figure(figsize=(10, 10))
+
+    params = get_diffusion_params(TIMESTEPS, device, eta=0.0)
+
+    model = UNet(32, TIMESTEPS).to(device)
+    model.load_state_dict(torch.load("model.pkl", weights_only=True))
+
+    model.eval()
+    generated_images = (
+        ddpm_sample_images(  # change to ddim_sample_images here to enable DDIM
+            model=model,
+            image_size=IMAGE_SIZE,
+            batch_size=16,
+            channels=CHANNELS,
+            device=device,
+            params=params,
+        )
+    )
+    show_images(generated_images, title="Generated Images")
+
+    # Keep the plot open after generation is finished
+    plt.show()
diff --git a/src/simple_ddpm/train.py b/src/tiny_ddpm/train.py
similarity index 91%
rename from src/simple_ddpm/train.py
rename to src/tiny_ddpm/train.py
index 3729c05..38e0377 100644
--- a/src/simple_ddpm/train.py
+++ b/src/tiny_ddpm/train.py
@@ -15,6 +15,7 @@ BATCH_SIZE = 128
 IMAGE_SIZE = 32
 CHANNELS = 3
 TIMESTEPS = 1000
+DDIM_TIMESTEPS = 100
 
 
 def count_parameters(model):
@@ -40,7 +41,10 @@ train_loader = DataLoader(
 
 
 def get_diffusion_params(
-    timesteps: int, device: torch.device
+    timesteps: int,
+    device: torch.device,
+    ddim_timesteps: int = DDIM_TIMESTEPS,
+    eta=0.0,
 ) -> Dict[str, torch.Tensor]:
     def linear_beta_schedule(timesteps):
         beta_start = 0.0001
@@ -59,12 +63,21 @@
 
     posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
 
+    ddim_sigma = eta * torch.sqrt(
+        (1.0 - alphas_cumprod_prev)
+        / (1.0 - alphas_cumprod)
+        * (1 - alphas_cumprod / alphas_cumprod_prev)
+    )
+
     return {
+        # DDPM Parameters
         "betas": betas.to(device),
         "alphas_cumprod": alphas_cumprod.to(device),
         "posterior_variance": posterior_variance.to(device),
         "one_over_alphas": one_over_alphas.to(device),
         "posterior_mean_coef": posterior_mean_coef.to(device),
+        # DDIM Parameters
+        "ddim_sigma": ddim_sigma.to(device),
     }
 
 
@@ -94,7 +107,6 @@ def get_loss_fn(model: torch.nn.Module, params: Dict[str, torch.Tensor]) -> Call
     return loss_fn
 
 
-# Training loop template
 def train_epoch(
     model: torch.nn.Module, optimize, train_loader: DataLoader, loss_fn: Callable
 ) -> float:
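
A note for readers of the diff (not part of the patch itself): the `ddim_sigma` schedule added to `get_diffusion_params` and the update applied in `ddim_sample` follow the DDIM sampler of Song et al. (2021). Writing ᾱ_t for `alphas_cumprod[t]` and "prev" for the previous timestep of the DDIM sub-sequence (`t - stride` in the code), the generalized reverse step is

$$
x_{\text{prev}} \;=\; \sqrt{\bar\alpha_{\text{prev}}}\,
\underbrace{\frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}}_{\text{predicted } x_0}
\;+\; \sqrt{1-\bar\alpha_{\text{prev}}-\sigma_t^2}\;\epsilon_\theta(x_t, t)
\;+\; \sigma_t z, \qquad z \sim \mathcal{N}(0, I),
$$

with

$$
\sigma_t \;=\; \eta\,\sqrt{\frac{1-\bar\alpha_{\text{prev}}}{1-\bar\alpha_t}}\,\sqrt{1-\frac{\bar\alpha_t}{\bar\alpha_{\text{prev}}}}.
$$

With the default `eta=0.0`, σ_t = 0 and the step is deterministic, which is the case this diff implements (the `noise * sigma` term in `ddim_sample` then contributes nothing).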
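
Relatedly, the README note about switching samplers amounts to a one-call change at the bottom of `src/tiny_ddpm/sample.py`. A minimal sketch of the DDIM variant of that call, reusing only names introduced in this diff (with `DDIM_TIMESTEPS = 100`, sampling walks a 100-step sub-sequence instead of all 1000 steps):

```python
# Inside the existing __main__ block of src/tiny_ddpm/sample.py:
# call the DDIM wrapper instead of the DDPM one; everything else stays the same.
generated_images = ddim_sample_images(
    model=model,
    image_size=IMAGE_SIZE,
    batch_size=16,
    channels=CHANNELS,
    device=device,
    params=params,
)
show_images(generated_images, title="Generated Images")
```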