diff --git a/src/tiny_ddpm/sample.py b/src/tiny_ddpm/sample.py index fd1cc24..de8bd87 100644 --- a/src/tiny_ddpm/sample.py +++ b/src/tiny_ddpm/sample.py @@ -1,10 +1,10 @@ -from typing import Dict, Union +from typing import Dict import torch from tqdm import tqdm import matplotlib.pyplot as plt import numpy as np from train import extract, get_diffusion_params -from train import TIMESTEPS, IMAGE_SIZE, CHANNELS, DDIM_TIMESTEPS +from train import TIMESTEPS, IMAGE_SIZE, CHANNELS, DDIM_TIMESTEPS, NORM_MEAN, NORM_STD from model import UNet torch.manual_seed(1) @@ -132,12 +132,13 @@ def ddim_sample_images( return x -def show_images(images: Union[torch.Tensor, np.array], title=""): +def show_images(images: torch.Tensor, title=""): """Display a batch of images in a grid""" - if isinstance(images, torch.Tensor): - images = images.detach().cpu().numpy() + mean = torch.tensor(NORM_MEAN, device=images.device).view(1, -1, 1, 1) + std = torch.tensor(NORM_STD, device=images.device).view(1, -1, 1, 1) + images = (images * std + mean).clip(0, 1) + images = images.detach().cpu().numpy() - images = (images * 0.25 + 0.5).clip(0, 1) for idx in range(min(16, len(images))): plt.subplot(4, 4, idx + 1) plt.imshow(np.transpose(images[idx], (1, 2, 0))) @@ -157,7 +158,7 @@ if __name__ == "__main__": model.eval() generated_images = ( - ddpm_sample_images( # change to ddim_sample_images here to enable DDIM + ddim_sample_images( # change to ddpm_sample_images here to use DDPM model=model, image_size=IMAGE_SIZE, batch_size=16, diff --git a/src/tiny_ddpm/train.py b/src/tiny_ddpm/train.py index 38e0377..2d6ffbb 100644 --- a/src/tiny_ddpm/train.py +++ b/src/tiny_ddpm/train.py @@ -17,6 +17,9 @@ CHANNELS = 3 TIMESTEPS = 1000 DDIM_TIMESTEPS = 100 +NORM_MEAN = (0.4914, 0.4822, 0.4465) +NORM_STD = (0.2470, 0.2435, 0.2616) + def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) @@ -27,7 +30,7 @@ transform = transforms.Compose( 
transforms.Resize(IMAGE_SIZE), transforms.RandomHorizontalFlip(0.5), transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + transforms.Normalize(NORM_MEAN, NORM_STD), ] )