
feat: training and sampling

Ramon Calvo · 1 year ago · branch: master · commit c7e1c8220f
Changed files:
  79  src/simple_ddpm/model.py
  13  src/simple_ddpm/sample.py
  22  src/simple_ddpm/train.py

src/simple_ddpm/model.py  (79 changes)

@@ -23,18 +23,54 @@ class Conv(nn.Module):
     def __init__(self, in_channels: int, out_channels: int):
         super(Conv, self).__init__()
-        self.conv = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
-            nn.BatchNorm2d(num_features=out_channels),
-            nn.SiLU(),
-            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
-            nn.BatchNorm2d(num_features=out_channels),
-            # nn.SiLU(),
-            nn.Dropout(0.1),
-        )
+        self.in_channels = in_channels
+
+        self.t_emb_layer = nn.Sequential(
+            nn.Linear(T_EMBEDDING_SIZE, in_channels)
+        )
+        self.conv = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.GroupNorm(8, in_channels),
+                    nn.SiLU(),
+                    nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
+                ),
+                nn.Sequential(
+                    nn.GroupNorm(8, in_channels),
+                    nn.SiLU(),
+                    nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
+                ),
+            ]
+        )
+        self.sa_norm = nn.GroupNorm(8, in_channels)
+        self.sa = nn.MultiheadAttention(in_channels, 4, dropout=0.1, batch_first=True)
+        self.out_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.conv(x)
+    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Input:
+            x (torch.Tensor): input of shape (B, C, H, W)
+            t_emb (torch.Tensor): embedding time input (B, t_emb)
+        """
+        B, _, H, W = x.shape
+        x_res = self.conv[0](x)
+        x_res = x_res + self.t_emb_layer(t_emb).view(B, self.in_channels, 1, 1).expand(
+            B, self.in_channels, H, W
+        )
+        x = x + self.conv[1](x_res)
+
+        C = x.shape[1]
+        in_att = self.sa_norm(x.reshape(B, C, -1)).transpose(1, 2)
+        out_att, _ = self.sa(in_att, in_att, in_att)
+        out_att = out_att.transpose(1, 2).reshape(B, C, H, W)
+        x = x + out_att
+
+        x = self.out_conv(x)
+        return x
 
 
 class Downsample(nn.Module):
@@ -76,9 +112,11 @@ class UNet(nn.Module):
             sinusoidal_positional_embedding(nb_timesteps, T_EMBEDDING_SIZE),
         )
 
+        self.preconv = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
+
         self.encoder = nn.ModuleList(
             [
-                Conv(in_channels + T_EMBEDDING_SIZE, 64),
+                Conv(64, 64),
                 Conv(64, 128),
                 Conv(128, 256),
                 Conv(256, 512),
@@ -94,32 +132,33 @@ class UNet(nn.Module):
                 Conv(2 * 512, 256),
                 Conv(2 * 256, 128),
                 Conv(2 * 128, 64),
-                Conv(2 * 64, out_channels),
+                Conv(2 * 64, 32),
             ]
         )
         self.upsamplers = nn.ModuleList(
             [Upsample(512), Upsample(256), Upsample(128), Upsample(64)]
         )
+        self.out_conv = nn.Conv2d(32, 3, kernel_size=3, padding=1)
 
     def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
-        B, H, W = x.shape[0], x.shape[2], x.shape[3]
-        t_embedding = self.position_embeddings[t]  # (B, T_DIM)
-        t_embedding = t_embedding.view(B, T_EMBEDDING_SIZE, 1, 1).expand(
-            B, T_EMBEDDING_SIZE, H, W
-        )
-        x = torch.cat((t_embedding, x), axis=1)
+        x = self.preconv(x)
+        t_emb = self.position_embeddings[t]  # (B, T_DIM)
 
         intermediates = []
         for enc, dow in zip(self.encoder, self.downsamplers):
-            x = enc(x)
+            x = enc(x, t_emb)
             intermediates.append(x)  # TODO: Is this a copy?
             x = dow(x)
 
-        x = self.bottleneck(x)
+        x = self.bottleneck(x, t_emb)
 
         for dec, up, m in zip(self.decoder, self.upsamplers, reversed(intermediates)):
            x = up(x)
            x = torch.concat((x, m), axis=1)  # Channel dimension
-            x = dec(x)
+            x = dec(x, t_emb)
 
+        x = self.out_conv(x)
         return x

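Note: the UNet constructor above calls sinusoidal_positional_embedding(nb_timesteps, T_EMBEDDING_SIZE), which is defined elsewhere in model.py and untouched by this commit. For context, a minimal sketch of a conventional transformer-style implementation follows; the repository's actual body may differ.

import math
import torch

def sinusoidal_positional_embedding(nb_timesteps: int, dim: int) -> torch.Tensor:
    # Sketch only: builds the standard sin/cos lookup table of shape
    # (nb_timesteps, dim); not part of this diff.
    positions = torch.arange(nb_timesteps).float().unsqueeze(1)  # (T, 1)
    freqs = torch.exp(
        torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)
    )  # (dim // 2,), frequencies 10000^(-2i/dim)
    emb = torch.zeros(nb_timesteps, dim)
    emb[:, 0::2] = torch.sin(positions * freqs)  # even columns
    emb[:, 1::2] = torch.cos(positions * freqs)  # odd columns
    return emb
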
src/simple_ddpm/sample.py  (13 changes)

@@ -32,15 +32,13 @@ def p_sample(model, x, t, params):
 @torch.no_grad()
 def sample_images(model, image_size, batch_size, channels, device, params):
     """Generate new images using the trained model"""
-    model.eval()
-
-    # Start from pure noise
     x = torch.randn(batch_size, channels, image_size, image_size).to(device)
 
-    # Gradually denoise the image
     for t in tqdm(reversed(range(TIMESTEPS)), desc="Sampling", total=TIMESTEPS):
         t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
         x = p_sample(model, x, t_batch, params)
-        if t % 100 == 0:
-            show_images(x)
 
         if x.isnan().any():
             raise ValueError(f"NaN detected in image at timestep {t}")
@@ -54,15 +52,17 @@ def show_images(images, title=""):
     images = images.detach().cpu().numpy()
     images = (images * 0.25 + 0.5).clip(0, 1)
 
-    plt.figure(figsize=(10, 10))
     for idx in range(min(16, len(images))):
         plt.subplot(4, 4, idx + 1)
         plt.imshow(np.transpose(images[idx], (1, 2, 0)))
         plt.axis("off")
 
     plt.suptitle(title)
-    plt.show()
+    plt.draw()
+    plt.pause(0.001)
+
+
+plt.figure(figsize=(10, 10))
 
 params = get_diffusion_params(TIMESTEPS, device)
 model = UNet(32, TIMESTEPS).to(device)
@@ -78,3 +78,4 @@ generated_images = sample_images(
     params=params,
 )
 show_images(generated_images, title="Generated Images")
+plt.show()

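Note: sample_images drives p_sample, whose body sits outside the hunks above. A reverse DDPM step consistent with the coefficients built in train.py (one_over_alphas, posterior_mean_coef, posterior_variance) would look roughly like the sketch below. The params dictionary keys, and the reuse of extract from train.py, are assumptions: the return block of get_diffusion_params is not visible in this diff.

import torch

@torch.no_grad()
def p_sample(model, x, t, params):
    # Sketch only: one ancestral sampling step x_t -> x_{t-1}.
    eps = model(x, t)  # predicted noise
    # Posterior mean: (1 / sqrt(alpha_t)) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps)
    mean = extract(params["one_over_alphas"], t, x.shape) * (
        x - extract(params["posterior_mean_coef"], t, x.shape) * eps
    )
    if t[0] == 0:
        return mean  # no noise injected at the final step
    var = extract(params["posterior_variance"], t, x.shape)
    return mean + torch.sqrt(var) * torch.randn_like(x)
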
src/simple_ddpm/train.py  (22 changes)

@@ -9,12 +9,17 @@ torch.manual_seed(42)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Hyperparameters
-NUM_EPOCHS = 128
+NUM_EPOCHS = 256
 BATCH_SIZE = 128
 IMAGE_SIZE = 32
 CHANNELS = 3
 TIMESTEPS = 1000
 
+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
 # Data loading and preprocessing
 transform = transforms.Compose(
     [
@@ -31,31 +36,27 @@ train_dataset = datasets.CIFAR10(
 )
 
 train_loader = DataLoader(
-    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True
+    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True
 )
 
-# Define beta schedule (linear schedule as an example)
+
 def linear_beta_schedule(timesteps):
     beta_start = 0.0001
     beta_end = 0.02
     return torch.linspace(beta_start, beta_end, timesteps)
 
-# Calculate diffusion parameters
+
 def get_diffusion_params(timesteps, device):
     betas = linear_beta_schedule(timesteps)
     alphas = 1.0 - betas
-    alphas_cumprod = torch.cumprod(alphas, axis=0)
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
     alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
     sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
 
-    # Calculate posterior mean coefficients
     one_over_alphas = 1.0 / torch.sqrt(alphas)
     posterior_mean_coef = betas / sqrt_one_minus_alphas_cumprod
 
-    # Posterior variance
     posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
 
     return {
@@ -67,7 +68,6 @@ def get_diffusion_params(timesteps, device):
     }
 
 
-# Utility functions
 def extract(a, t, x_shape):
     """Extract coefficients at specified timesteps t"""
     batch_size = t.shape[0]
@@ -75,7 +75,6 @@ def extract(a, t, x_shape):
     return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
 
 
-# Training utilities
 def get_loss_fn(model, params):
     def loss_fn(x_0):
         batch_size = x_0.shape[0]
@@ -117,9 +116,8 @@ def train_epoch(model, optimizer, train_loader, loss_fn):
 if __name__ == "__main__":
     model = UNet(32, TIMESTEPS).to(device)
-    model_avg = torch.optim.swa_utils.AveragedModel(
-        model, multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.9999)
-    )
+    nb_params = count_parameters(model)
+    print(f"Total number of parameters: {nb_params}")
     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95))
     params = get_diffusion_params(TIMESTEPS, device)
     loss_fn = get_loss_fn(model, params)

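Note: the training objective itself (the body of loss_fn) falls outside these hunks. A standard noise-prediction DDPM loss consistent with the parameters above would be along the lines of the sketch below; the "alphas_cumprod" key is an assumption about what get_diffusion_params returns, since its return block is truncated here, and TIMESTEPS plus extract are taken from the surrounding module scope.

import torch
import torch.nn.functional as F

def get_loss_fn(model, params):
    # Sketch only: epsilon-prediction objective from Ho et al. (2020).
    def loss_fn(x_0):
        batch_size = x_0.shape[0]
        # Sample a random timestep per image and the target noise.
        t = torch.randint(0, TIMESTEPS, (batch_size,), device=x_0.device)
        noise = torch.randn_like(x_0)
        # Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        a_bar = extract(params["alphas_cumprod"], t, x_0.shape)
        x_t = torch.sqrt(a_bar) * x_0 + torch.sqrt(1.0 - a_bar) * noise
        # Train the network to recover the injected noise.
        return F.mse_loss(model(x_t, t), noise)

    return loss_fn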