5 changed files with 193 additions and 98 deletions
@@ -1,94 +0,0 @@
from typing import Dict, Union
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from train import extract, get_diffusion_params
from train import TIMESTEPS, IMAGE_SIZE, CHANNELS
from model import UNet

torch.manual_seed(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@torch.no_grad()
def p_sample(
    model, x: torch.Tensor, t: torch.Tensor, params: Dict[str, torch.Tensor]
) -> torch.Tensor:
    """Sample from the model at timestep t"""
    predicted_noise = model(x, t)

    one_over_alphas = extract(params["one_over_alphas"], t, x.shape)
    posterior_mean_coef = extract(params["posterior_mean_coef"], t, x.shape)

    pred_mean = one_over_alphas * (x - posterior_mean_coef * predicted_noise)

    posterior_variance = extract(params["posterior_variance"], t, x.shape)

    if t[0] > 0:
        noise = torch.randn_like(x)
        return pred_mean + torch.sqrt(posterior_variance) * noise
    else:
        return pred_mean


@torch.no_grad()
def sample_images(
    model: torch.nn.Module,
    image_size: int,
    batch_size: int,
    channels: int,
    device: torch.device,
    params: Dict[str, torch.Tensor],
):
    """Generate new images using the trained model"""
    x = torch.randn(batch_size, channels, image_size, image_size).to(device)

    for t in tqdm(reversed(range(TIMESTEPS)), desc="Sampling", total=TIMESTEPS):
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
        x = p_sample(model, x, t_batch, params)
        if t % 100 == 0:
            show_images(x)

        if x.isnan().any():
            raise ValueError(f"NaN detected in image at timestep {t}")

    return x


def show_images(images: Union[torch.Tensor, np.array], title=""):
    """Display a batch of images in a grid"""
    if isinstance(images, torch.Tensor):
        images = images.detach().cpu().numpy()

    images = (images * 0.25 + 0.5).clip(0, 1)
    for idx in range(min(16, len(images))):
        plt.subplot(4, 4, idx + 1)
        plt.imshow(np.transpose(images[idx], (1, 2, 0)))
        plt.axis("off")
    plt.suptitle(title)
    plt.draw()
    plt.pause(0.001)


if __name__ == "__main__":
    plt.figure(figsize=(10, 10))

    params = get_diffusion_params(TIMESTEPS, device)

    model = UNet(32, TIMESTEPS).to(device)
    model.load_state_dict(torch.load("model.pkl", weights_only=True))

    model.eval()
    generated_images = sample_images(
        model=model,
        image_size=IMAGE_SIZE,
        batch_size=16,
        channels=CHANNELS,
        device=device,
        params=params,
    )
    show_images(generated_images, title="Generated Images")

    # Keep the plot open after generation is finished
    plt.show()
@@ -0,0 +1,173 @@
from typing import Dict, Union
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from train import extract, get_diffusion_params
from train import TIMESTEPS, IMAGE_SIZE, CHANNELS, DDIM_TIMESTEPS
from model import UNet

torch.manual_seed(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@torch.no_grad()
def ddpm_sample(
    model: torch.nn.Module,
    x: torch.Tensor,
    t: torch.Tensor,
    params: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Sample from the model at timestep t"""
    predicted_noise = model(x, t)

    one_over_alphas = extract(params["one_over_alphas"], t, x.shape)
    posterior_mean_coef = extract(params["posterior_mean_coef"], t, x.shape)

    pred_mean = one_over_alphas * (x - posterior_mean_coef * predicted_noise)

    posterior_variance = extract(params["posterior_variance"], t, x.shape)

    if t[0] > 0:
        noise = torch.randn_like(x)
        return pred_mean + torch.sqrt(posterior_variance) * noise
    else:
        return pred_mean


@torch.no_grad()
def ddim_sample(
    model: torch.nn.Module,
    x: torch.Tensor,
    t: torch.Tensor,
    params: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Sample from the model in a non-Markovian way (DDIM)"""
    stride = TIMESTEPS // DDIM_TIMESTEPS
    t_prev = t - stride
    predicted_noise = model(x, t)

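    # On the last step t_prev is negative: clamp it so extract() gets a valid
    # index, then overwrite alphabar_prev with 1.0, which makes the update
    # below return the clean x_0 estimate.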
    alphas_prod = extract(params["alphas_cumprod"], t, x.shape)
    valid_mask = (t_prev >= 0).view(-1, 1, 1, 1)
    safe_t_prev = torch.maximum(t_prev, torch.tensor(0, device=device))
    alphas_prod_prev = extract(params["alphas_cumprod"], safe_t_prev, x.shape)
    alphas_prod_prev = torch.where(
        valid_mask, alphas_prod_prev, torch.ones_like(alphas_prod_prev)
    )

    sigma = extract(params["ddim_sigma"], t, x.shape)

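    # Standard DDIM update: invert the forward process
    # x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps to estimate x_0,
    # then step to t_prev along the predicted noise direction. (With sigma > 0
    # the direction coefficient is strictly sqrt(1 - alphabar_prev - sigma^2);
    # the form below matches it exactly when eta = 0, i.e. sigma = 0.)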
    pred_x0 = (x - (1 - alphas_prod).sqrt() * predicted_noise) / alphas_prod.sqrt()

    pred = (
        alphas_prod_prev.sqrt() * pred_x0
        + (1.0 - alphas_prod_prev).sqrt() * predicted_noise
    )

    if t[0] > 0:
        noise = torch.randn_like(x)
        pred = pred + noise * sigma

    return pred


@torch.no_grad()
def ddpm_sample_images(
    model: torch.nn.Module,
    image_size: int,
    batch_size: int,
    channels: int,
    device: torch.device,
    params: Dict[str, torch.Tensor],
):
    """Generate new images with the full TIMESTEPS-step DDPM reverse loop"""
    x = torch.randn(batch_size, channels, image_size, image_size).to(device)

    for t in tqdm(reversed(range(TIMESTEPS)), desc="DDPM Sampling", total=TIMESTEPS):
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
        x = ddpm_sample(model, x, t_batch, params)
        if t % 100 == 0:
            show_images(x)

        if x.isnan().any():
            raise ValueError(f"NaN detected in image at timestep {t}")

    return x


def get_ddim_timesteps(
    total_timesteps: int, num_sampling_timesteps: int
) -> torch.Tensor:
    """Gets the timesteps used for the DDIM process."""
    assert total_timesteps % num_sampling_timesteps == 0
    stride = total_timesteps // num_sampling_timesteps
    timesteps = torch.arange(0, total_timesteps, stride)
    return timesteps.flip(0)


@torch.no_grad()
def ddim_sample_images(
    model: torch.nn.Module,
    image_size: int,
    batch_size: int,
    channels: int,
    device: torch.device,
    params: Dict[str, torch.Tensor],
):
    """Generate new images with the strided DDIM_TIMESTEPS-step reverse loop"""
    x = torch.randn(batch_size, channels, image_size, image_size).to(device)

    timesteps = get_ddim_timesteps(TIMESTEPS, DDIM_TIMESTEPS)

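    # Walk the strided schedule all the way down to t = 0; the final step is
    # the one that returns the clean x_0 estimate, so it must not be skipped.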
    for i in tqdm(range(len(timesteps)), desc="DDIM Sampling"):
        t = torch.full((batch_size,), timesteps[i], device=device, dtype=torch.long)
        x_before = x.clone()
        x = ddim_sample(model, x, t, params)
        print(f"Step {i}, max diff: {(x - x_before).abs().max().item()}")

        if x.isnan().any():
            raise ValueError(f"NaN detected at timestep {timesteps[i]}")

        if i % 10 == 0:
            show_images(x)

    return x


def show_images(images: Union[torch.Tensor, np.ndarray], title=""):
    """Display a batch of images in a grid"""
    if isinstance(images, torch.Tensor):
        images = images.detach().cpu().numpy()

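    # Undo the training normalization (assumes inputs were normalized with
    # mean 0.5 and std 0.25 during training).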
    images = (images * 0.25 + 0.5).clip(0, 1)
    for idx in range(min(16, len(images))):
        plt.subplot(4, 4, idx + 1)
        plt.imshow(np.transpose(images[idx], (1, 2, 0)))
        plt.axis("off")
    plt.suptitle(title)
    plt.draw()
    plt.pause(0.001)


if __name__ == "__main__":
    plt.figure(figsize=(10, 10))

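    # eta scales the DDIM noise term sigma: eta=0.0 gives fully deterministic
    # sampling, while eta=1.0 recovers the DDPM posterior variance.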
    params = get_diffusion_params(TIMESTEPS, device, eta=0.0)

    model = UNet(32, TIMESTEPS).to(device)
    model.load_state_dict(torch.load("model.pkl", weights_only=True))

    model.eval()
    # change to ddim_sample_images here to enable DDIM
    generated_images = ddpm_sample_images(
        model=model,
        image_size=IMAGE_SIZE,
        batch_size=16,
        channels=CHANNELS,
        device=device,
        params=params,
    )
    show_images(generated_images, title="Generated Images")

    # Keep the plot open after generation is finished
    plt.show()