Browse Source

fix: use correct mean and std for sampling

master
Ramon Calvo 1 year ago
parent
commit
eb1ec55f66
  1. 15
      src/tiny_ddpm/sample.py
  2. 5
      src/tiny_ddpm/train.py

15
src/tiny_ddpm/sample.py

@ -1,10 +1,10 @@
from typing import Dict, Union
from typing import Dict
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from train import extract, get_diffusion_params
from train import TIMESTEPS, IMAGE_SIZE, CHANNELS, DDIM_TIMESTEPS
from train import TIMESTEPS, IMAGE_SIZE, CHANNELS, DDIM_TIMESTEPS, NORM_MEAN, NORM_STD
from model import UNet
torch.manual_seed(1)
@ -132,12 +132,13 @@ def ddim_sample_images(
return x
def show_images(images: Union[torch.Tensor, np.array], title=""):
def show_images(images: torch.Tensor, title=""):
"""Display a batch of images in a grid"""
if isinstance(images, torch.Tensor):
images = images.detach().cpu().numpy()
mean = torch.tensor(NORM_MEAN).to(device).view(1, 3, 1, 1).expand_as(images)
std = torch.tensor(NORM_STD).to(device).view(1, 3, 1, 1).expand_as(images)
images = (images * std + mean).clip(0, 1)
images = images.detach().cpu().numpy()
images = (images * 0.25 + 0.5).clip(0, 1)
for idx in range(min(16, len(images))):
plt.subplot(4, 4, idx + 1)
plt.imshow(np.transpose(images[idx], (1, 2, 0)))
@ -157,7 +158,7 @@ if __name__ == "__main__":
model.eval()
generated_images = (
ddpm_sample_images( # change to ddim_sample_images here to enable DDIM
ddim_sample_images( # change to ddim_sample_images here to enable DDIM
model=model,
image_size=IMAGE_SIZE,
batch_size=16,

5
src/tiny_ddpm/train.py

@ -17,6 +17,9 @@ CHANNELS = 3
TIMESTEPS = 1000
DDIM_TIMESTEPS = 100
NORM_MEAN = (0.4914, 0.4822, 0.4465)
NORM_STD = (0.2470, 0.2435, 0.2616)
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
@ -27,7 +30,7 @@ transform = transforms.Compose(
transforms.Resize(IMAGE_SIZE),
transforms.RandomHorizontalFlip(0.5),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
transforms.Normalize(NORM_MEAN, NORM_STD),
]
)

Loading…
Cancel
Save