from typing import Dict, Union

import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

from train import extract, get_diffusion_params
from train import TIMESTEPS, IMAGE_SIZE, CHANNELS, DDIM_TIMESTEPS
from model import UNet

torch.manual_seed(1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@torch.no_grad()
def ddpm_sample(
    model: torch.nn.Module,
    x: torch.Tensor,
    t: torch.Tensor,
    params: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Sample x_{t-1} from the model given x_t (one reverse DDPM step)."""
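    # One reverse DDPM step:
    #   x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)) + sqrt(var_t) * z
    # This assumes get_diffusion_params in train.py provides "one_over_alphas" as
    # 1/sqrt(alpha_t), "posterior_mean_coef" as beta_t / sqrt(1 - alpha_bar_t), and
    # "posterior_variance" as the posterior variance beta_tilde_t.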
    predicted_noise = model(x, t)

    one_over_alphas = extract(params["one_over_alphas"], t, x.shape)
    posterior_mean_coef = extract(params["posterior_mean_coef"], t, x.shape)

    pred_mean = one_over_alphas * (x - posterior_mean_coef * predicted_noise)

    posterior_variance = extract(params["posterior_variance"], t, x.shape)

    if t[0] > 0:
        noise = torch.randn_like(x)
        return pred_mean + torch.sqrt(posterior_variance) * noise
    else:
        # No noise is added on the final step (t = 0)
        return pred_mean


@torch.no_grad()
def ddim_sample(
    model: torch.nn.Module,
    x: torch.Tensor,
    t: torch.Tensor,
    params: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Sample x_{t - stride} from x_t in a non-Markovian way (DDIM)."""
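    # DDIM update (Song et al., 2021):
    #   x_{t_prev} = sqrt(alpha_bar_{t_prev}) * x0_pred + sqrt(1 - alpha_bar_{t_prev} - sigma^2) * eps_theta + sigma * z
    # The direction term below uses sqrt(1 - alpha_bar_{t_prev}), which equals the exact
    # coefficient when sigma = 0 (eta = 0.0, the deterministic setting used in __main__).
    # "ddim_sigma" is assumed to be precomputed by get_diffusion_params in train.py.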
    stride = TIMESTEPS // DDIM_TIMESTEPS
    t_prev = t - stride
    predicted_noise = model(x, t)

    alphas_prod = extract(params["alphas_cumprod"], t, x.shape)
    # On the last step t_prev may be negative; clamp it and force alpha_bar_prev = 1 there
    valid_mask = (t_prev >= 0).view(-1, 1, 1, 1)
    safe_t_prev = torch.maximum(t_prev, torch.tensor(0, device=device))
    alphas_prod_prev = extract(params["alphas_cumprod"], safe_t_prev, x.shape)
    alphas_prod_prev = torch.where(
        valid_mask, alphas_prod_prev, torch.ones_like(alphas_prod_prev)
    )

    sigma = extract(params["ddim_sigma"], t, x.shape)

    # Predict the clean image x_0, then re-noise it to the previous timestep
    pred_x0 = (x - (1 - alphas_prod).sqrt() * predicted_noise) / alphas_prod.sqrt()

    pred = (
        alphas_prod_prev.sqrt() * pred_x0
        + (1.0 - alphas_prod_prev).sqrt() * predicted_noise
    )

    if t[0] > 0:
        noise = torch.randn_like(x)
        pred = pred + noise * sigma

    return pred


@torch.no_grad()
def ddpm_sample_images(
    model: torch.nn.Module,
    image_size: int,
    batch_size: int,
    channels: int,
    device: torch.device,
    params: Dict[str, torch.Tensor],
):
    """Generate new images with the trained model using the full DDPM reverse chain."""
    # Start from pure Gaussian noise and denoise step by step from t = T - 1 down to 0
    x = torch.randn(batch_size, channels, image_size, image_size).to(device)

    for t in tqdm(reversed(range(TIMESTEPS)), desc="DDPM Sampling", total=TIMESTEPS):
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
        x = ddpm_sample(model, x, t_batch, params)
        if t % 100 == 0:
            show_images(x)

        if x.isnan().any():
            raise ValueError(f"NaN detected in image at timestep {t}")

    return x


def get_ddim_timesteps(
    total_timesteps: int, num_sampling_timesteps: int
) -> torch.Tensor:
    """Gets the timesteps used for the DDIM process."""
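    # Worked example (hypothetical values; the real TIMESTEPS / DDIM_TIMESTEPS come from train.py):
    #   get_ddim_timesteps(1000, 50) -> tensor([980, 960, ..., 20, 0]), i.e. 50 steps with stride 20.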
    assert total_timesteps % num_sampling_timesteps == 0
    stride = total_timesteps // num_sampling_timesteps
    timesteps = torch.arange(0, total_timesteps, stride)
    return timesteps.flip(0)


@torch.no_grad()
def ddim_sample_images(
    model: torch.nn.Module,
    image_size: int,
    batch_size: int,
    channels: int,
    device: torch.device,
    params: Dict[str, torch.Tensor],
):
    """Generate new images with the trained model using strided DDIM sampling."""
    x = torch.randn(batch_size, channels, image_size, image_size).to(device)

    timesteps = get_ddim_timesteps(TIMESTEPS, DDIM_TIMESTEPS)

    # Iterate over every strided timestep, including the final t = 0 step, which maps
    # the sample to the model's x_0 prediction (ddim_sample handles t_prev < 0)
    for i in tqdm(range(len(timesteps)), desc="DDIM Sampling"):
        t = torch.full((batch_size,), timesteps[i], device=device, dtype=torch.long)
        x_before = x.clone()
        x = ddim_sample(model, x, t, params)
        # Debug output: how much the sample changed in this step
        print(f"Step {i}, max diff: {(x - x_before).abs().max().item()}")

        if x.isnan().any():
            raise ValueError(f"NaN detected at timestep {timesteps[i]}")

        if i % 10 == 0:
            show_images(x)

    return x


def show_images(images: Union[torch.Tensor, np.ndarray], title=""):
    """Display a batch of images in a grid"""
    if isinstance(images, torch.Tensor):
        images = images.detach().cpu().numpy()

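    # Undo the training normalization before plotting; this assumes the dataset in
    # train.py was normalized to roughly mean 0.5 and std 0.25 per channel.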
    images = (images * 0.25 + 0.5).clip(0, 1)
    for idx in range(min(16, len(images))):
        plt.subplot(4, 4, idx + 1)
        plt.imshow(np.transpose(images[idx], (1, 2, 0)))
        plt.axis("off")
    plt.suptitle(title)
    plt.draw()
    plt.pause(0.001)


if __name__ == "__main__":
    plt.figure(figsize=(10, 10))

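    # eta = 0.0 is assumed to make get_diffusion_params build an all-zero "ddim_sigma"
    # schedule, i.e. fully deterministic DDIM sampling.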
    params = get_diffusion_params(TIMESTEPS, device, eta=0.0)

    model = UNet(32, TIMESTEPS).to(device)
    model.load_state_dict(torch.load("model.pkl", weights_only=True))

    model.eval()
    generated_images = ddpm_sample_images(  # change to ddim_sample_images here to enable DDIM
        model=model,
        image_size=IMAGE_SIZE,
        batch_size=16,
        channels=CHANNELS,
        device=device,
        params=params,
    )
    show_images(generated_images, title="Generated Images")

    # Keep the plot open after generation is finished
    plt.show()