from typing import Dict, Union
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from train import extract, get_diffusion_params
from train import TIMESTEPS, IMAGE_SIZE, CHANNELS, DDIM_TIMESTEPS
from model import UNet

# Fix the seed for reproducible sampling and pick the compute device.
torch.manual_seed(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
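

# DDPM reverse step (Ho et al., 2020):
#   x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t))
#             + sqrt(posterior_variance_t) * z,   z ~ N(0, I)
# The lookup tables "one_over_alphas", "posterior_mean_coef" and
# "posterior_variance" are assumed to hold the corresponding per-timestep
# coefficients precomputed by get_diffusion_params in train.py.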
@torch.no_grad()
def ddpm_sample(
    model: torch.nn.Module,
    x: torch.Tensor,
    t: torch.Tensor,
    params: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Perform one DDPM reverse step: sample x_{t-1} given x_t at timestep t."""
    predicted_noise = model(x, t)
    one_over_alphas = extract(params["one_over_alphas"], t, x.shape)
    posterior_mean_coef = extract(params["posterior_mean_coef"], t, x.shape)
    # Posterior mean of x_{t-1} given x_t and the predicted noise.
    pred_mean = one_over_alphas * (x - posterior_mean_coef * predicted_noise)
    posterior_variance = extract(params["posterior_variance"], t, x.shape)
    if t[0] > 0:
        # Add noise scaled by the posterior standard deviation.
        noise = torch.randn_like(x)
        return pred_mean + torch.sqrt(posterior_variance) * noise
    else:
        # No noise is added at the final step (t == 0).
        return pred_mean
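

# DDIM update (Song et al., 2021), applied on a strided subset of the training
# timesteps:
#   x_{t_prev} = sqrt(alpha_bar_{t_prev}) * x0_pred
#                + sqrt(1 - alpha_bar_{t_prev} - sigma_t^2) * eps_theta
#                + sigma_t * z
# With eta = 0 (as used in __main__ below) sigma_t is zero and the update is
# deterministic. "alphas_cumprod" and "ddim_sigma" are assumed to be
# precomputed by get_diffusion_params in train.py.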
@torch.no_grad()
def ddim_sample(
    model: torch.nn.Module,
    x: torch.Tensor,
    t: torch.Tensor,
    params: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Sample from the model in a non-Markovian way (DDIM)."""
    stride = TIMESTEPS // DDIM_TIMESTEPS
    t_prev = t - stride
    predicted_noise = model(x, t)
    alphas_prod = extract(params["alphas_cumprod"], t, x.shape)
    # For the last step t_prev is negative; clamp the index and force
    # alpha_bar_{t_prev} to 1 there (fully denoised).
    valid_mask = (t_prev >= 0).view(-1, 1, 1, 1)
    safe_t_prev = torch.maximum(t_prev, torch.tensor(0, device=device))
    alphas_prod_prev = extract(params["alphas_cumprod"], safe_t_prev, x.shape)
    alphas_prod_prev = torch.where(
        valid_mask, alphas_prod_prev, torch.ones_like(alphas_prod_prev)
    )
    sigma = extract(params["ddim_sigma"], t, x.shape)
    # Predict x_0 from x_t and the predicted noise.
    pred_x0 = (x - (1 - alphas_prod).sqrt() * predicted_noise) / alphas_prod.sqrt()
    # Direction term towards x_{t_prev}; the sigma**2 correction follows the
    # general DDIM update and has no effect when eta == 0.
    pred = (
        alphas_prod_prev.sqrt() * pred_x0
        + (1.0 - alphas_prod_prev - sigma ** 2).clamp(min=0.0).sqrt() * predicted_noise
    )
    if t[0] > 0:
        noise = torch.randn_like(x)
        pred = pred + noise * sigma
    return pred
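

# Full ancestral DDPM sampling: start from pure Gaussian noise and apply the
# reverse step for every one of the TIMESTEPS training timesteps.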
@torch.no_grad()
def ddpm_sample_images(
    model: torch.nn.Module,
    image_size: int,
    batch_size: int,
    channels: int,
    device: torch.device,
    params: Dict[str, torch.Tensor],
):
    """Generate new images using the trained model"""
    x = torch.randn(batch_size, channels, image_size, image_size).to(device)
    for t in tqdm(reversed(range(TIMESTEPS)), desc="DDPM Sampling", total=TIMESTEPS):
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
        x = ddpm_sample(model, x, t_batch, params)
        if t % 100 == 0:
            show_images(x)
        if x.isnan().any():
            raise ValueError(f"NaN detected in image at timestep {t}")
    return x


def get_ddim_timesteps(
    total_timesteps: int, num_sampling_timesteps: int
) -> torch.Tensor:
    """Return the (descending) timesteps used for the DDIM sampling process."""
    assert total_timesteps % num_sampling_timesteps == 0
    stride = total_timesteps // num_sampling_timesteps
    timesteps = torch.arange(0, total_timesteps, stride)
    return timesteps.flip(0)
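

# DDIM sampling visits only the strided subset of timesteps returned by
# get_ddim_timesteps, e.g. [980, 960, ..., 20, 0] for 1000 training steps and
# 50 sampling steps (illustrative values; the actual constants come from
# TIMESTEPS and DDIM_TIMESTEPS in train.py).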
@torch.no_grad()
def ddim_sample_images(
    model: torch.nn.Module,
    image_size: int,
    batch_size: int,
    channels: int,
    device: torch.device,
    params: Dict[str, torch.Tensor],
):
    """Generate new images using the trained model"""
    x = torch.randn(batch_size, channels, image_size, image_size).to(device)
    timesteps = get_ddim_timesteps(TIMESTEPS, DDIM_TIMESTEPS)
    # Iterate over all strided timesteps; ddim_sample handles the final step
    # (t == 0) via its t_prev < 0 mask.
    for i in tqdm(range(len(timesteps)), desc="DDIM Sampling"):
        t = torch.full((batch_size,), timesteps[i], device=device, dtype=torch.long)
        x_before = x.clone()
        x = ddim_sample(model, x, t, params)
        # Debug output: how much the sample changed in this step.
        print(f"Step {i}, max diff: {(x - x_before).abs().max().item()}")
        if x.isnan().any():
            raise ValueError(f"NaN detected at timestep {timesteps[i]}")
        if i % 10 == 0:
            show_images(x)
    return x
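

# De-normalization in show_images assumes training images were normalized to
# roughly mean 0.5 and std 0.25 (inferred from the inverse transform below).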
def show_images(images: Union[torch.Tensor, np.ndarray], title=""):
    """Display a batch of images in a 4x4 grid"""
    if isinstance(images, torch.Tensor):
        images = images.detach().cpu().numpy()
    # Undo the training normalization and clip into displayable range.
    images = (images * 0.25 + 0.5).clip(0, 1)
    for idx in range(min(16, len(images))):
        plt.subplot(4, 4, idx + 1)
        plt.imshow(np.transpose(images[idx], (1, 2, 0)))
        plt.axis("off")
    plt.suptitle(title)
    plt.draw()
    plt.pause(0.001)


if __name__ == "__main__":
    plt.figure(figsize=(10, 10))
    params = get_diffusion_params(TIMESTEPS, device, eta=0.0)
    model = UNet(32, TIMESTEPS).to(device)
    model.load_state_dict(torch.load("model.pkl", weights_only=True))
    model.eval()
    generated_images = (
        ddpm_sample_images(  # change to ddim_sample_images here to enable DDIM
            model=model,
            image_size=IMAGE_SIZE,
            batch_size=16,
            channels=CHANNELS,
            device=device,
            params=params,
        )
    )
    show_images(generated_images, title="Generated Images")
    # Keep the plot open after generation is finished
    plt.show()