diff --git a/src/simple_ddpm/model.py b/src/simple_ddpm/model.py
index 6a722ec..bc36424 100644
--- a/src/simple_ddpm/model.py
+++ b/src/simple_ddpm/model.py
@@ -23,18 +23,54 @@ class Conv(nn.Module):
     def __init__(self, in_channels: int, out_channels: int):
         super(Conv, self).__init__()
 
-        self.conv = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
-            nn.BatchNorm2d(num_features=out_channels),
-            nn.SiLU(),
-            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
-            nn.BatchNorm2d(num_features=out_channels),
-            # nn.SiLU(),
-            nn.Dropout(0.1),
+        self.in_channels = in_channels
+
+        self.t_emb_layer = nn.Sequential(
+            nn.SiLU(), nn.Linear(T_EMBEDDING_SIZE, in_channels)
         )
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.conv(x)
+        self.conv = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.GroupNorm(8, in_channels),
+                    nn.SiLU(),
+                    nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
+                ),
+                nn.Sequential(
+                    nn.GroupNorm(8, in_channels),
+                    nn.SiLU(),
+                    nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
+                ),
+            ]
+        )
+
+        self.sa_norm = nn.GroupNorm(8, in_channels)
+        self.sa = nn.MultiheadAttention(in_channels, 4, dropout=0.1, batch_first=True)
+
+        self.out_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+        Input:
+            x (torch.Tensor): input of shape (B, C, H, W)
+            t_emb (torch.Tensor): time embedding of shape (B, T_EMBEDDING_SIZE)
+        """
+        B, _, H, W = x.shape
+        x_res = self.conv[0](x)
+        x_res = x_res + self.t_emb_layer(t_emb).view(B, self.in_channels, 1, 1).expand(
+            B, self.in_channels, H, W
+        )
+        x = x + self.conv[1](x_res)
+
+        C = x.shape[1]
+        in_att = self.sa_norm(x.reshape(B, C, -1)).transpose(1, 2)
+        out_att, _ = self.sa(in_att, in_att, in_att)
+        out_att = out_att.transpose(1, 2).reshape(B, C, H, W)
+        x = x + out_att
+
+        x = self.out_conv(x)
+
+        return x
 
 
 class Downsample(nn.Module):
@@ -76,9 +112,11 @@ class UNet(nn.Module):
             sinusoidal_positional_embedding(nb_timesteps, T_EMBEDDING_SIZE),
         )
 
+        self.preconv = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
+
         self.encoder = nn.ModuleList(
             [
-                Conv(in_channels + T_EMBEDDING_SIZE, 64),
+                Conv(64, 64),
                 Conv(64, 128),
                 Conv(128, 256),
                 Conv(256, 512),
@@ -94,32 +132,33 @@ class UNet(nn.Module):
                 Conv(2 * 512, 256),
                 Conv(2 * 256, 128),
                 Conv(2 * 128, 64),
-                Conv(2 * 64, out_channels),
+                Conv(2 * 64, 32),
             ]
         )
         self.upsamplers = nn.ModuleList(
            [Upsample(512), Upsample(256), Upsample(128), Upsample(64)]
         )
 
+        self.out_conv = nn.Conv2d(32, 3, kernel_size=3, padding=1)
+
     def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
-        B, H, W = x.shape[0], x.shape[2], x.shape[3]
-        t_embedding = self.position_embeddings[t]  # (B, T_DIM)
-        t_embedding = t_embedding.view(B, T_EMBEDDING_SIZE, 1, 1).expand(
-            B, T_EMBEDDING_SIZE, H, W
-        )
-        x = torch.cat((t_embedding, x), axis=1)
+        x = self.preconv(x)
+
+        t_emb = self.position_embeddings[t]  # (B, T_DIM)
 
         intermediates = []
         for enc, dow in zip(self.encoder, self.downsamplers):
-            x = enc(x)
+            x = enc(x, t_emb)
             intermediates.append(x)  # TODO: Is this a copy?
             x = dow(x)
 
-        x = self.bottleneck(x)
+        x = self.bottleneck(x, t_emb)
 
         for dec, up, m in zip(self.decoder, self.upsamplers, reversed(intermediates)):
             x = up(x)
             x = torch.concat((x, m), axis=1)  # Channel dimension
-            x = dec(x)
+            x = dec(x, t_emb)
+
+        x = self.out_conv(x)
 
         return x
diff --git a/src/simple_ddpm/sample.py b/src/simple_ddpm/sample.py
index a10e163..9402f09 100644
--- a/src/simple_ddpm/sample.py
+++ b/src/simple_ddpm/sample.py
@@ -32,15 +32,13 @@ def p_sample(model, x, t, params):
 @torch.no_grad()
 def sample_images(model, image_size, batch_size, channels, device, params):
     """Generate new images using the trained model"""
-    model.eval()
-
-    # Start from pure noise
     x = torch.randn(batch_size, channels, image_size, image_size).to(device)
 
-    # Gradually denoise the image
     for t in tqdm(reversed(range(TIMESTEPS)), desc="Sampling", total=TIMESTEPS):
         t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
         x = p_sample(model, x, t_batch, params)
+        if t % 100 == 0:
+            show_images(x)
 
         if x.isnan().any():
             raise ValueError(f"NaN detected in image at timestep {t}")
@@ -54,14 +52,16 @@ def show_images(images, title=""):
     images = images.detach().cpu().numpy()
     images = (images * 0.25 + 0.5).clip(0, 1)
 
-    plt.figure(figsize=(10, 10))
     for idx in range(min(16, len(images))):
         plt.subplot(4, 4, idx + 1)
         plt.imshow(np.transpose(images[idx], (1, 2, 0)))
         plt.axis("off")
 
     plt.suptitle(title)
-    plt.show()
+    plt.draw()
+    plt.pause(0.001)
+
+plt.figure(figsize=(10, 10))
 
 
 params = get_diffusion_params(TIMESTEPS, device)
@@ -78,3 +78,4 @@ generated_images = sample_images(
     params=params,
 )
 show_images(generated_images, title="Generated Images")
+plt.show()
diff --git a/src/simple_ddpm/train.py b/src/simple_ddpm/train.py
index c533156..ffbed39 100644
--- a/src/simple_ddpm/train.py
+++ b/src/simple_ddpm/train.py
@@ -9,12 +9,17 @@ torch.manual_seed(42)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Hyperparameters
-NUM_EPOCHS = 128
+NUM_EPOCHS = 256
 BATCH_SIZE = 128
 IMAGE_SIZE = 32
 CHANNELS = 3
 TIMESTEPS = 1000
 
+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
 # Data loading and preprocessing
 transform = transforms.Compose(
     [
@@ -31,31 +36,27 @@ train_dataset = datasets.CIFAR10(
 )
 
 train_loader = DataLoader(
-    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True
+    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True
 )
 
 
-# Define beta schedule (linear schedule as an example)
 def linear_beta_schedule(timesteps):
     beta_start = 0.0001
     beta_end = 0.02
     return torch.linspace(beta_start, beta_end, timesteps)
 
 
-# Calculate diffusion parameters
def get_diffusion_params(timesteps, device):
     betas = linear_beta_schedule(timesteps)
     alphas = 1.0 - betas
-    alphas_cumprod = torch.cumprod(alphas, axis=0)
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
     alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
     sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
 
-    # Calculate posterior mean coefficients
     one_over_alphas = 1.0 / torch.sqrt(alphas)
     posterior_mean_coef = betas / sqrt_one_minus_alphas_cumprod
 
-    # Posterior variance
     posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
 
     return {
@@ -67,7 +68,6 @@ def get_diffusion_params(timesteps, device):
     }
 
 
-# Utility functions
 def extract(a, t, x_shape):
     """Extract coefficients at specified timesteps t"""
     batch_size = t.shape[0]
@@ -75,7 +75,6 @@ def extract(a, t, x_shape):
     return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
 
 
-# Training utilities
 def get_loss_fn(model, params):
     def loss_fn(x_0):
         batch_size = x_0.shape[0]
@@ -117,9 +116,8 @@ def train_epoch(model, optimizer, train_loader, loss_fn):
 
 if __name__ == "__main__":
     model = UNet(32, TIMESTEPS).to(device)
-    model_avg = torch.optim.swa_utils.AveragedModel(
-        model, multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.9999)
-    )
+    nb_params = count_parameters(model)
+    print(f"Total number of parameters: {nb_params}")
     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95))
     params = get_diffusion_params(TIMESTEPS, device)
     loss_fn = get_loss_fn(model, params)