
feat: training and sampling

master
Ramon Calvo, 1 year ago
commit c7e1c8220f
1. src/simple_ddpm/model.py   (79 changed lines)
2. src/simple_ddpm/sample.py  (13 changed lines)
3. src/simple_ddpm/train.py   (22 changed lines)

src/simple_ddpm/model.py (79 changed lines)

@@ -23,18 +23,54 @@ class Conv(nn.Module):
     def __init__(self, in_channels: int, out_channels: int):
         super(Conv, self).__init__()
-        self.conv = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
-            nn.BatchNorm2d(num_features=out_channels),
+        self.in_channels = in_channels
+        self.t_emb_layer = nn.Sequential(
+            nn.SiLU(), nn.Linear(T_EMBEDDING_SIZE, in_channels)
+        )
+        self.conv = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.GroupNorm(8, in_channels),
+                    nn.SiLU(),
+                    nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
+                ),
+                nn.Sequential(
+                    nn.GroupNorm(8, in_channels),
+                    nn.SiLU(),
-            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
-            nn.BatchNorm2d(num_features=out_channels),
-            # nn.SiLU(),
+                    nn.Dropout(0.1),
+                    nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
+                ),
+            ]
         )

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.conv(x)
+        self.sa_norm = nn.GroupNorm(8, in_channels)
+        self.sa = nn.MultiheadAttention(in_channels, 4, dropout=0.1, batch_first=True)
+        self.out_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Input:
+            x (torch.Tensor): input of shape (B, C, H, W)
+            t_emb (torch.Tensor): embedding time input (B, t_emb)
+        """
+        B, _, H, W = x.shape
+        x_res = self.conv[0](x)
+        x_res = x_res + self.t_emb_layer(t_emb).view(B, self.in_channels, 1, 1).expand(
+            B, self.in_channels, H, W
+        )
+        x = x + self.conv[1](x_res)
+
+        C = x.shape[1]
+        in_att = self.sa_norm(x.reshape(B, C, -1)).transpose(1, 2)
+        out_att, _ = self.sa(in_att, in_att, in_att)
+        out_att = out_att.transpose(1, 2).reshape(B, C, H, W)
+        x = x + out_att
+
+        x = self.out_conv(x)
+        return x


 class Downsample(nn.Module):
@@ -76,9 +112,11 @@ class UNet(nn.Module):
             sinusoidal_positional_embedding(nb_timesteps, T_EMBEDDING_SIZE),
         )
+
+        self.preconv = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
         self.encoder = nn.ModuleList(
             [
-                Conv(in_channels + T_EMBEDDING_SIZE, 64),
+                Conv(64, 64),
                 Conv(64, 128),
                 Conv(128, 256),
                 Conv(256, 512),
@@ -94,32 +132,33 @@ class UNet(nn.Module):
                 Conv(2 * 512, 256),
                 Conv(2 * 256, 128),
                 Conv(2 * 128, 64),
-                Conv(2 * 64, out_channels),
+                Conv(2 * 64, 32),
             ]
         )
         self.upsamplers = nn.ModuleList(
             [Upsample(512), Upsample(256), Upsample(128), Upsample(64)]
         )
+        self.out_conv = nn.Conv2d(32, 3, kernel_size=3, padding=1)

     def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
-        B, H, W = x.shape[0], x.shape[2], x.shape[3]
-        t_embedding = self.position_embeddings[t]  # (B, T_DIM)
-        t_embedding = t_embedding.view(B, T_EMBEDDING_SIZE, 1, 1).expand(
-            B, T_EMBEDDING_SIZE, H, W
-        )
-        x = torch.cat((t_embedding, x), axis=1)
+        x = self.preconv(x)
+        t_emb = self.position_embeddings[t]  # (B, T_DIM)
+
         intermediates = []
         for enc, dow in zip(self.encoder, self.downsamplers):
-            x = enc(x)
+            x = enc(x, t_emb)
             intermediates.append(x)  # TODO: Is this a copy?
             x = dow(x)
-        x = self.bottleneck(x)
+        x = self.bottleneck(x, t_emb)

         for dec, up, m in zip(self.decoder, self.upsamplers, reversed(intermediates)):
             x = up(x)
             x = torch.concat((x, m), axis=1)  # Channel dimension
-            x = dec(x)
+            x = dec(x, t_emb)
+
+        x = self.out_conv(x)
         return x
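Note: the model indexes a position_embeddings table built by sinusoidal_positional_embedding(nb_timesteps, T_EMBEDDING_SIZE), which this diff references but does not define. A minimal sketch of such a helper, assuming the standard transformer-style sinusoidal formulation (the repo's actual implementation may differ):

    import torch

    def sinusoidal_positional_embedding(nb_timesteps: int, dim: int) -> torch.Tensor:
        # Hypothetical reconstruction: returns a table of shape (nb_timesteps, dim);
        # assumes dim (T_EMBEDDING_SIZE here) is even.
        t = torch.arange(nb_timesteps, dtype=torch.float32).unsqueeze(1)  # (T, 1)
        exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim    # (dim/2,)
        freqs = torch.exp(-torch.log(torch.tensor(10000.0)) * exponents)  # (dim/2,)
        emb = torch.zeros(nb_timesteps, dim)
        emb[:, 0::2] = torch.sin(t * freqs)  # even columns get sine
        emb[:, 1::2] = torch.cos(t * freqs)  # odd columns get cosine
        return emb

Each timestep then gets a unique, smooth embedding vector, which Conv.t_emb_layer projects down to the block's channel count before adding it to the feature map.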

src/simple_ddpm/sample.py (13 changed lines)

@@ -32,15 +32,13 @@ def p_sample(model, x, t, params):
 @torch.no_grad()
 def sample_images(model, image_size, batch_size, channels, device, params):
     """Generate new images using the trained model"""
     model.eval()

     # Start from pure noise
     x = torch.randn(batch_size, channels, image_size, image_size).to(device)

     # Gradually denoise the image
     for t in tqdm(reversed(range(TIMESTEPS)), desc="Sampling", total=TIMESTEPS):
         t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
         x = p_sample(model, x, t_batch, params)
-        if t % 100 == 0:
-            show_images(x)
         if x.isnan().any():
             raise ValueError(f"NaN detected in image at timestep {t}")
@@ -54,15 +52,17 @@ def show_images(images, title=""):
     images = images.detach().cpu().numpy()
     images = (images * 0.25 + 0.5).clip(0, 1)

     plt.figure(figsize=(10, 10))
     for idx in range(min(16, len(images))):
         plt.subplot(4, 4, idx + 1)
         plt.imshow(np.transpose(images[idx], (1, 2, 0)))
         plt.axis("off")
     plt.suptitle(title)
-    plt.show()
+    plt.draw()
+    plt.pause(0.001)


+plt.figure(figsize=(10, 10))
 params = get_diffusion_params(TIMESTEPS, device)
 model = UNet(32, TIMESTEPS).to(device)
@@ -78,3 +78,4 @@ generated_images = sample_images(
     params=params,
 )
 show_images(generated_images, title="Generated Images")
+plt.show()

src/simple_ddpm/train.py (22 changed lines)

@@ -9,12 +9,17 @@ torch.manual_seed(42)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 # Hyperparameters
-NUM_EPOCHS = 128
+NUM_EPOCHS = 256
 BATCH_SIZE = 128
 IMAGE_SIZE = 32
 CHANNELS = 3
 TIMESTEPS = 1000

+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
 # Data loading and preprocessing
 transform = transforms.Compose(
     [
@@ -31,31 +36,27 @@ train_dataset = datasets.CIFAR10(
 )

 train_loader = DataLoader(
-    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True
+    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True
 )


-# Define beta schedule (linear schedule as an example)
 def linear_beta_schedule(timesteps):
     beta_start = 0.0001
     beta_end = 0.02
     return torch.linspace(beta_start, beta_end, timesteps)


-# Calculate diffusion parameters
 def get_diffusion_params(timesteps, device):
     betas = linear_beta_schedule(timesteps)
     alphas = 1.0 - betas
-    alphas_cumprod = torch.cumprod(alphas, axis=0)
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
     alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
     sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

-    # Calculate posterior mean coefficients
     one_over_alphas = 1.0 / torch.sqrt(alphas)
     posterior_mean_coef = betas / sqrt_one_minus_alphas_cumprod

-    # Posterior variance
     posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

     return {
@@ -67,7 +68,6 @@ def get_diffusion_params(timesteps, device):
     }


-# Utility functions
 def extract(a, t, x_shape):
     """Extract coefficients at specified timesteps t"""
     batch_size = t.shape[0]
@@ -75,7 +75,6 @@ def extract(a, t, x_shape):
     return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


-# Training utilities
 def get_loss_fn(model, params):
     def loss_fn(x_0):
         batch_size = x_0.shape[0]
@@ -117,9 +116,8 @@ def train_epoch(model, optimizer, train_loader, loss_fn):
 if __name__ == "__main__":
     model = UNet(32, TIMESTEPS).to(device)
-    model_avg = torch.optim.swa_utils.AveragedModel(
-        model, multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.9999)
-    )
+    nb_params = count_parameters(model)
+    print(f"Total number of parameters: {nb_params}")
     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95))
     params = get_diffusion_params(TIMESTEPS, device)
     loss_fn = get_loss_fn(model, params)
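get_loss_fn appears only as context here. Assuming the usual epsilon-prediction objective (sample t uniformly, corrupt x_0 with the closed-form q(x_t | x_0), regress the injected noise), the elided body would look roughly like the sketch below; the keys alphas_cumprod and sqrt_one_minus_alphas_cumprod are guesses based on the variables computed in get_diffusion_params:

    import torch
    import torch.nn.functional as F

    def get_loss_fn(model, params):
        def loss_fn(x_0):
            batch_size = x_0.shape[0]
            # One uniformly sampled timestep per image
            t = torch.randint(0, TIMESTEPS, (batch_size,), device=x_0.device)
            noise = torch.randn_like(x_0)
            # q(x_t | x_0): sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
            sqrt_ac = extract(torch.sqrt(params["alphas_cumprod"]), t, x_0.shape)
            sqrt_om = extract(params["sqrt_one_minus_alphas_cumprod"], t, x_0.shape)
            x_t = sqrt_ac * x_0 + sqrt_om * noise
            # The network is trained to predict the injected noise
            return F.mse_loss(model(x_t, t), noise)

        return loss_fn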
