5 changed files with 193 additions and 98 deletions
@ -1,94 +0,0 @@ |
|||||
from typing import Dict, Union |
|
||||
import torch |
|
||||
from tqdm import tqdm |
|
||||
import matplotlib.pyplot as plt |
|
||||
import numpy as np |
|
||||
from train import extract, get_diffusion_params |
|
||||
from train import TIMESTEPS, IMAGE_SIZE, CHANNELS |
|
||||
from model import UNet |
|
||||
|
|
||||
torch.manual_seed(1) |
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
||||
|
|
||||
|
|
||||
@torch.no_grad() |
|
||||
def p_sample( |
|
||||
model, x: torch.Tensor, t: torch.Tensor, params: Dict[str, torch.Tensor] |
|
||||
) -> torch.Tensor: |
|
||||
"""Sample from the model at timestep t""" |
|
||||
predicted_noise = model(x, t) |
|
||||
|
|
||||
one_over_alphas = extract(params["one_over_alphas"], t, x.shape) |
|
||||
posterior_mean_coef = extract(params["posterior_mean_coef"], t, x.shape) |
|
||||
|
|
||||
pred_mean = one_over_alphas * (x - posterior_mean_coef * predicted_noise) |
|
||||
|
|
||||
posterior_variance = extract(params["posterior_variance"], t, x.shape) |
|
||||
|
|
||||
if t[0] > 0: |
|
||||
noise = torch.randn_like(x) |
|
||||
return pred_mean + torch.sqrt(posterior_variance) * noise |
|
||||
else: |
|
||||
return pred_mean |
|
||||
|
|
||||
|
|
||||
@torch.no_grad() |
|
||||
def sample_images( |
|
||||
model: torch.nn.Module, |
|
||||
image_size: int, |
|
||||
batch_size: int, |
|
||||
channels: int, |
|
||||
device: torch.device, |
|
||||
params: Dict[str, torch.Tensor], |
|
||||
): |
|
||||
"""Generate new images using the trained model""" |
|
||||
x = torch.randn(batch_size, channels, image_size, image_size).to(device) |
|
||||
|
|
||||
for t in tqdm(reversed(range(TIMESTEPS)), desc="Sampling", total=TIMESTEPS): |
|
||||
t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long) |
|
||||
x = p_sample(model, x, t_batch, params) |
|
||||
if t % 100 == 0: |
|
||||
show_images(x) |
|
||||
|
|
||||
if x.isnan().any(): |
|
||||
raise ValueError(f"NaN detected in image at timestep {t}") |
|
||||
|
|
||||
return x |
|
||||
|
|
||||
|
|
||||
def show_images(images: Union[torch.Tensor, np.array], title=""): |
|
||||
"""Display a batch of images in a grid""" |
|
||||
if isinstance(images, torch.Tensor): |
|
||||
images = images.detach().cpu().numpy() |
|
||||
|
|
||||
images = (images * 0.25 + 0.5).clip(0, 1) |
|
||||
for idx in range(min(16, len(images))): |
|
||||
plt.subplot(4, 4, idx + 1) |
|
||||
plt.imshow(np.transpose(images[idx], (1, 2, 0))) |
|
||||
plt.axis("off") |
|
||||
plt.suptitle(title) |
|
||||
plt.draw() |
|
||||
plt.pause(0.001) |
|
||||
|
|
||||
|
|
||||
if __name__ == "__main__": |
|
||||
plt.figure(figsize=(10, 10)) |
|
||||
|
|
||||
params = get_diffusion_params(TIMESTEPS, device) |
|
||||
|
|
||||
model = UNet(32, TIMESTEPS).to(device) |
|
||||
model.load_state_dict(torch.load("model.pkl", weights_only=True)) |
|
||||
|
|
||||
model.eval() |
|
||||
generated_images = sample_images( |
|
||||
model=model, |
|
||||
image_size=IMAGE_SIZE, |
|
||||
batch_size=16, |
|
||||
channels=CHANNELS, |
|
||||
device=device, |
|
||||
params=params, |
|
||||
) |
|
||||
show_images(generated_images, title="Generated Images") |
|
||||
|
|
||||
# Keep the plot open after generation is finished |
|
||||
plt.show() |
|
||||
@ -0,0 +1,173 @@ |
|||||
|
from typing import Dict, Union |
||||
|
import torch |
||||
|
from tqdm import tqdm |
||||
|
import matplotlib.pyplot as plt |
||||
|
import numpy as np |
||||
|
from train import extract, get_diffusion_params |
||||
|
from train import TIMESTEPS, IMAGE_SIZE, CHANNELS, DDIM_TIMESTEPS |
||||
|
from model import UNet |
||||
|
|
||||
|
torch.manual_seed(1) |
||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
||||
|
|
||||
|
|
||||
|
@torch.no_grad() |
||||
|
def ddpm_sample( |
||||
|
model: torch.nn.Module, |
||||
|
x: torch.Tensor, |
||||
|
t: torch.Tensor, |
||||
|
params: Dict[str, torch.Tensor], |
||||
|
) -> torch.Tensor: |
||||
|
"""Sample from the model at timestep t""" |
||||
|
predicted_noise = model(x, t) |
||||
|
|
||||
|
one_over_alphas = extract(params["one_over_alphas"], t, x.shape) |
||||
|
posterior_mean_coef = extract(params["posterior_mean_coef"], t, x.shape) |
||||
|
|
||||
|
pred_mean = one_over_alphas * (x - posterior_mean_coef * predicted_noise) |
||||
|
|
||||
|
posterior_variance = extract(params["posterior_variance"], t, x.shape) |
||||
|
|
||||
|
if t[0] > 0: |
||||
|
noise = torch.randn_like(x) |
||||
|
return pred_mean + torch.sqrt(posterior_variance) * noise |
||||
|
else: |
||||
|
return pred_mean |
||||
|
|
||||
|
|
||||
|
@torch.no_grad() |
||||
|
def ddim_sample( |
||||
|
model: torch.nn.Module, |
||||
|
x: torch.Tensor, |
||||
|
t: torch.Tensor, |
||||
|
params: Dict[str, torch.Tensor], |
||||
|
) -> torch.Tensor: |
||||
|
"""Sample from the model in a non-markovian way (DDIM)""" |
||||
|
stride = TIMESTEPS // DDIM_TIMESTEPS |
||||
|
t_prev = t - stride |
||||
|
predicted_noise = model(x, t) |
||||
|
|
||||
|
alphas_prod = extract(params["alphas_cumprod"], t, x.shape) |
||||
|
valid_mask = (t_prev >= 0).view(-1, 1, 1, 1) |
||||
|
safe_t_prev = torch.maximum(t_prev, torch.tensor(0, device=device)) |
||||
|
alphas_prod_prev = extract(params["alphas_cumprod"], safe_t_prev, x.shape) |
||||
|
alphas_prod_prev = torch.where( |
||||
|
valid_mask, alphas_prod_prev, torch.ones_like(alphas_prod_prev) |
||||
|
) |
||||
|
|
||||
|
sigma = extract(params["ddim_sigma"], t, x.shape) |
||||
|
|
||||
|
pred_x0 = (x - (1 - alphas_prod).sqrt() * predicted_noise) / alphas_prod.sqrt() |
||||
|
|
||||
|
pred = ( |
||||
|
alphas_prod_prev.sqrt() * pred_x0 |
||||
|
+ (1.0 - alphas_prod_prev).sqrt() * predicted_noise |
||||
|
) |
||||
|
|
||||
|
if t[0] > 0: |
||||
|
noise = torch.randn_like(x) |
||||
|
pred = pred + noise * sigma |
||||
|
|
||||
|
return pred |
||||
|
|
||||
|
|
||||
|
@torch.no_grad() |
||||
|
def ddpm_sample_images( |
||||
|
model: torch.nn.Module, |
||||
|
image_size: int, |
||||
|
batch_size: int, |
||||
|
channels: int, |
||||
|
device: torch.device, |
||||
|
params: Dict[str, torch.Tensor], |
||||
|
): |
||||
|
"""Generate new images using the trained model""" |
||||
|
x = torch.randn(batch_size, channels, image_size, image_size).to(device) |
||||
|
|
||||
|
for t in tqdm(reversed(range(TIMESTEPS)), desc="DDPM Sampling", total=TIMESTEPS): |
||||
|
t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long) |
||||
|
x = ddpm_sample(model, x, t_batch, params) |
||||
|
if t % 100 == 0: |
||||
|
show_images(x) |
||||
|
|
||||
|
if x.isnan().any(): |
||||
|
raise ValueError(f"NaN detected in image at timestep {t}") |
||||
|
|
||||
|
return x |
||||
|
|
||||
|
|
||||
|
def get_ddim_timesteps( |
||||
|
total_timesteps: int, num_sampling_timesteps: int |
||||
|
) -> torch.Tensor: |
||||
|
"""Gets the timesteps used for the DDIM process.""" |
||||
|
assert total_timesteps % num_sampling_timesteps == 0 |
||||
|
stride = total_timesteps // num_sampling_timesteps |
||||
|
timesteps = torch.arange(0, total_timesteps, stride) |
||||
|
return timesteps.flip(0) |
||||
|
|
||||
|
|
||||
|
@torch.no_grad() |
||||
|
def ddim_sample_images( |
||||
|
model: torch.nn.Module, |
||||
|
image_size: int, |
||||
|
batch_size: int, |
||||
|
channels: int, |
||||
|
device: torch.device, |
||||
|
params: Dict[str, torch.Tensor], |
||||
|
): |
||||
|
"""Generate new images using the trained model""" |
||||
|
x = torch.randn(batch_size, channels, image_size, image_size).to(device) |
||||
|
|
||||
|
timesteps = get_ddim_timesteps(TIMESTEPS, DDIM_TIMESTEPS) |
||||
|
|
||||
|
for i in tqdm(range(len(timesteps) - 1), desc="DDIM Sampling"): |
||||
|
t = torch.full((batch_size,), timesteps[i], device=device, dtype=torch.long) |
||||
|
x_before = x.clone() |
||||
|
x = ddim_sample(model, x, t, params) |
||||
|
print(f"Step {i}, max diff: {(x - x_before).abs().max().item()}") |
||||
|
|
||||
|
if x.isnan().any(): |
||||
|
raise ValueError(f"NaN detected at timestep {timesteps[i]}") |
||||
|
|
||||
|
if i % 10 == 0: |
||||
|
show_images(x) |
||||
|
return x |
||||
|
|
||||
|
|
||||
|
def show_images(images: Union[torch.Tensor, np.array], title=""): |
||||
|
"""Display a batch of images in a grid""" |
||||
|
if isinstance(images, torch.Tensor): |
||||
|
images = images.detach().cpu().numpy() |
||||
|
|
||||
|
images = (images * 0.25 + 0.5).clip(0, 1) |
||||
|
for idx in range(min(16, len(images))): |
||||
|
plt.subplot(4, 4, idx + 1) |
||||
|
plt.imshow(np.transpose(images[idx], (1, 2, 0))) |
||||
|
plt.axis("off") |
||||
|
plt.suptitle(title) |
||||
|
plt.draw() |
||||
|
plt.pause(0.001) |
||||
|
|
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
plt.figure(figsize=(10, 10)) |
||||
|
|
||||
|
params = get_diffusion_params(TIMESTEPS, device, eta=0.0) |
||||
|
|
||||
|
model = UNet(32, TIMESTEPS).to(device) |
||||
|
model.load_state_dict(torch.load("model.pkl", weights_only=True)) |
||||
|
|
||||
|
model.eval() |
||||
|
generated_images = ( |
||||
|
ddpm_sample_images( # change to ddim_sample_images here to enable DDIM |
||||
|
model=model, |
||||
|
image_size=IMAGE_SIZE, |
||||
|
batch_size=16, |
||||
|
channels=CHANNELS, |
||||
|
device=device, |
||||
|
params=params, |
||||
|
) |
||||
|
) |
||||
|
show_images(generated_images, title="Generated Images") |
||||
|
|
||||
|
# Keep the plot open after generation is finished |
||||
|
plt.show() |
||||
Loading…
Reference in new issue