Browse Source

fix: use correct mean and std for sampling

master
Ramon Calvo 1 year ago
parent
commit
eb1ec55f66
  1. 13
      src/tiny_ddpm/sample.py
  2. 5
      src/tiny_ddpm/train.py

13
src/tiny_ddpm/sample.py

@ -1,10 +1,10 @@
from typing import Dict, Union from typing import Dict
import torch import torch
from tqdm import tqdm from tqdm import tqdm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from train import extract, get_diffusion_params from train import extract, get_diffusion_params
from train import TIMESTEPS, IMAGE_SIZE, CHANNELS, DDIM_TIMESTEPS from train import TIMESTEPS, IMAGE_SIZE, CHANNELS, DDIM_TIMESTEPS, NORM_MEAN, NORM_STD
from model import UNet from model import UNet
torch.manual_seed(1) torch.manual_seed(1)
@ -132,12 +132,13 @@ def ddim_sample_images(
return x return x
def show_images(images: Union[torch.Tensor, np.array], title=""): def show_images(images: torch.Tensor, title=""):
"""Display a batch of images in a grid""" """Display a batch of images in a grid"""
if isinstance(images, torch.Tensor): mean = torch.tensor(NORM_MEAN).to(device).view(1, 3, 1, 1).expand_as(images)
std = torch.tensor(NORM_STD).to(device).view(1, 3, 1, 1).expand_as(images)
images = (images * std + mean).clip(0, 1)
images = images.detach().cpu().numpy() images = images.detach().cpu().numpy()
images = (images * 0.25 + 0.5).clip(0, 1)
for idx in range(min(16, len(images))): for idx in range(min(16, len(images))):
plt.subplot(4, 4, idx + 1) plt.subplot(4, 4, idx + 1)
plt.imshow(np.transpose(images[idx], (1, 2, 0))) plt.imshow(np.transpose(images[idx], (1, 2, 0)))
@ -157,7 +158,7 @@ if __name__ == "__main__":
model.eval() model.eval()
generated_images = ( generated_images = (
ddpm_sample_images( # change to ddim_sample_images here to enable DDIM ddim_sample_images( # change to ddim_sample_images here to enable DDIM
model=model, model=model,
image_size=IMAGE_SIZE, image_size=IMAGE_SIZE,
batch_size=16, batch_size=16,

5
src/tiny_ddpm/train.py

@ -17,6 +17,9 @@ CHANNELS = 3
TIMESTEPS = 1000 TIMESTEPS = 1000
DDIM_TIMESTEPS = 100 DDIM_TIMESTEPS = 100
NORM_MEAN = (0.4914, 0.4822, 0.4465)
NORM_STD = (0.2470, 0.2435, 0.2616)
def count_parameters(model): def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad) return sum(p.numel() for p in model.parameters() if p.requires_grad)
@ -27,7 +30,7 @@ transform = transforms.Compose(
transforms.Resize(IMAGE_SIZE), transforms.Resize(IMAGE_SIZE),
transforms.RandomHorizontalFlip(0.5), transforms.RandomHorizontalFlip(0.5),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), transforms.Normalize(NORM_MEAN, NORM_STD),
] ]
) )

Loading…
Cancel
Save