@@ -1,10 +1,10 @@
-from typing import Dict, Union
+from typing import Dict
 import torch
 from tqdm import tqdm
 import matplotlib.pyplot as plt
 import numpy as np
 from train import extract, get_diffusion_params
-from train import TIMESTEPS, IMAGE_SIZE, CHANNELS, DDIM_TIMESTEPS
+from train import TIMESTEPS, IMAGE_SIZE, CHANNELS, DDIM_TIMESTEPS, NORM_MEAN, NORM_STD
 from model import UNet
 
 torch.manual_seed(1)
@@ -132,12 +132,13 @@ def ddim_sample_images(
     return x
 
 
-def show_images(images: Union[torch.Tensor, np.array], title=""):
+def show_images(images: torch.Tensor, title=""):
     """Display a batch of images in a grid"""
-    if isinstance(images, torch.Tensor):
-        images = images.detach().cpu().numpy()
+    mean = torch.tensor(NORM_MEAN).to(device).view(1, 3, 1, 1).expand_as(images)
+    std = torch.tensor(NORM_STD).to(device).view(1, 3, 1, 1).expand_as(images)
+    images = (images * std + mean).clip(0, 1)
+    images = images.detach().cpu().numpy()
 
-    images = (images * 0.25 + 0.5).clip(0, 1)
     for idx in range(min(16, len(images))):
         plt.subplot(4, 4, idx + 1)
         plt.imshow(np.transpose(images[idx], (1, 2, 0)))
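
The substantive change in this hunk: show_images previously undid normalization with the hard-coded `images * 0.25 + 0.5`; it now inverts the training-time normalization (in the style of torchvision's transforms.Normalize) using the per-channel NORM_MEAN / NORM_STD statistics exported by train.py, and only converts to NumPy afterwards. A minimal round-trip sketch of that inversion, with assumed NORM_MEAN / NORM_STD values that stand in for whatever train.py actually defines:

    import torch

    NORM_MEAN = (0.5, 0.5, 0.5)     # assumed values, for illustration only
    NORM_STD = (0.25, 0.25, 0.25)

    x = torch.rand(16, 3, 32, 32)                    # a batch scaled to [0, 1]
    mean = torch.tensor(NORM_MEAN).view(1, 3, 1, 1)  # broadcast over N, H, W
    std = torch.tensor(NORM_STD).view(1, 3, 1, 1)

    normalized = (x - mean) / std                    # forward normalization
    restored = (normalized * std + mean).clip(0, 1)  # the new show_images logic
    assert torch.allclose(restored, x, atol=1e-6)

With scalar statistics of 0.5 / 0.25 the old and new expressions coincide; the tensor form additionally handles per-channel means and stds.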
@@ -157,7 +158,7 @@ if __name__ == "__main__":
 
     model.eval()
     generated_images = (
-        ddpm_sample_images( # change to ddim_sample_images here to enable DDIM
+        ddim_sample_images( # change to ddpm_sample_images here to disable DDIM
             model=model,
             image_size=IMAGE_SIZE,
             batch_size=16,