
refactor

CALVO GONZALEZ Ramon committed 10 months ago · branch: master · commit d6a97ae1d7
6 changed files:
  176  src/ddpm/diffusion.py
  207  src/ddpm/generate_circle_dataset.py
  350  src/ddpm/sample.py
  200  src/ddpm/train.py
  329  src/ddpm/utils.py
   29  src/ddpm/values.py

src/ddpm/diffusion.py (176 lines added)

@@ -0,0 +1,176 @@
from typing import Dict

import torch
from tqdm import tqdm

TIMESTEPS = 1000
DDIM_TIMESTEPS = 500


@torch.compile
@torch.no_grad()
def ddpm_sample(
    model: torch.nn.Module,
    x: torch.Tensor,
    t: torch.Tensor,
    params: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Sample from the model at timestep t"""
    predicted_noise = model(x, t)
    one_over_alphas = extract(params["one_over_alphas"], t, x.shape)
    posterior_mean_coef = extract(params["posterior_mean_coef"], t, x.shape)
    pred_mean = one_over_alphas * (x - posterior_mean_coef * predicted_noise)
    posterior_variance = extract(params["posterior_variance"], t, x.shape)
    if t[0] > 0:
        noise = torch.randn_like(x)
        return pred_mean + torch.sqrt(posterior_variance) * noise
    else:
        return pred_mean
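# With one_over_alphas = 1/sqrt(alpha_t) and posterior_mean_coef =
# beta_t / sqrt(1 - alphabar_t) (see get_diffusion_params below), the step above
# is the standard DDPM ancestral update (Ho et al., 2020):
#   x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t/sqrt(1 - alphabar_t) * eps_theta(x_t, t)) + sigma_t * z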
@torch.no_grad()
def ddim_sample(
    model: torch.nn.Module,
    x: torch.Tensor,
    t: torch.Tensor,
    params: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Sample from the model in a non-markovian way (DDIM)"""
    device = next(model.parameters()).device
    stride = TIMESTEPS // DDIM_TIMESTEPS
    t_prev = t - stride
    predicted_noise = model(x, t)
    alphas_prod = extract(params["alphas_cumprod"], t, x.shape)
    valid_mask = (t_prev >= 0).view(-1, 1, 1, 1)
    safe_t_prev = torch.maximum(t_prev, torch.tensor(0, device=device))
    alphas_prod_prev = extract(params["alphas_cumprod"], safe_t_prev, x.shape)
    alphas_prod_prev = torch.where(
        valid_mask, alphas_prod_prev, torch.ones_like(alphas_prod_prev)
    )
    sigma = extract(params["ddim_sigma"], t, x.shape)
    pred_x0 = (x - (1 - alphas_prod).sqrt() * predicted_noise) / alphas_prod.sqrt()
    pred = (
        alphas_prod_prev.sqrt() * pred_x0
        + (1.0 - alphas_prod_prev).sqrt() * predicted_noise
    )
    if t[0] > 0:
        noise = torch.randn_like(x)
        pred = pred + noise * sigma
    return pred
@torch.no_grad()
def ddpm_sample_images(
    model: torch.nn.Module,
    image_size: int,
    batch_size: int,
    channels: int,
    device: torch.device,
    params: Dict[str, torch.Tensor],
):
    """Generate new images using the trained model"""
    x = torch.randn(batch_size, channels, image_size, image_size).to(device)
    for t in tqdm(
        reversed(range(TIMESTEPS)), desc="DDPM Sampling", total=TIMESTEPS, leave=False
    ):
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
        x = ddpm_sample(model, x, t_batch, params)
        if x.isnan().any():
            raise ValueError(f"NaN detected in image at timestep {t}")
    return x


def get_ddim_timesteps(
    total_timesteps: int, num_sampling_timesteps: int
) -> torch.Tensor:
    """Gets the timesteps used for the DDIM process."""
    assert total_timesteps % num_sampling_timesteps == 0
    stride = total_timesteps // num_sampling_timesteps
    timesteps = torch.arange(0, total_timesteps, stride)
    return timesteps.flip(0)
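# Sanity check: get_ddim_timesteps(1000, 500) returns tensor([998, 996, ..., 2, 0]),
# i.e. 500 evenly strided timesteps in descending order, ending at 0.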
@torch.no_grad()
def ddim_sample_images(
    model: torch.nn.Module,
    image_size: int,
    batch_size: int,
    channels: int,
    device: torch.device,
    params: Dict[str, torch.Tensor],
):
    """Generate new images using the trained model"""
    x = torch.randn(batch_size, channels, image_size, image_size).to(device)
    timesteps = get_ddim_timesteps(TIMESTEPS, DDIM_TIMESTEPS)
    for i in tqdm(range(len(timesteps) - 1), desc="DDIM Sampling"):
        t = torch.full((batch_size,), timesteps[i], device=device, dtype=torch.long)
        x = ddim_sample(model, x, t, params)
        if x.isnan().any():
            raise ValueError(f"NaN detected at timestep {timesteps[i]}")
    return x
def extract(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size):
    """Extract coefficients at specified timesteps t"""
    batch_size = t.shape[0]
    out = a.gather(-1, t)
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
def get_diffusion_params(
    timesteps: int,
    device: torch.device,
    ddim_timesteps: int = DDIM_TIMESTEPS,
    eta=0.0,
) -> Dict[str, torch.Tensor]:
    def linear_beta_schedule(timesteps):
        beta_start = 0.0001
        beta_end = 0.02
        return torch.linspace(beta_start, beta_end, timesteps)

    betas = linear_beta_schedule(timesteps)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
    one_over_alphas = 1.0 / torch.sqrt(alphas)
    posterior_mean_coef = betas / sqrt_one_minus_alphas_cumprod
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    ddim_sigma = eta * torch.sqrt(
        (1.0 - alphas_cumprod_prev)
        / (1.0 - alphas_cumprod)
        * (1 - alphas_cumprod / alphas_cumprod_prev)
    )
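    # eta scales the DDIM noise term (eq. (16) in Song et al., "Denoising
    # Diffusion Implicit Models"): eta=0.0 gives deterministic sampling,
    # eta=1.0 recovers the DDPM posterior variance.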
    return {
        # DDPM Parameters
        "betas": betas.to(device),
        "alphas_cumprod": alphas_cumprod.to(device),
        "posterior_variance": posterior_variance.to(device),
        "one_over_alphas": one_over_alphas.to(device),
        "posterior_mean_coef": posterior_mean_coef.to(device),
        # DDIM Parameters
        "ddim_sigma": ddim_sigma.to(device),
    }
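Usage sketch (not part of the commit; assumes a trained noise-prediction UNet from model.py, with illustrative sizes):

import torch
from diffusion import TIMESTEPS, get_diffusion_params, ddpm_sample_images
from model import UNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
params = get_diffusion_params(TIMESTEPS, device)
model = UNet(32, TIMESTEPS).to(device)  # load trained weights before sampling
model.eval()
images = ddpm_sample_images(
    model, image_size=32, batch_size=16, channels=3, device=device, params=params
)  # -> (16, 3, 32, 32) tensor of generated samples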

src/ddpm/generate_circle_dataset.py (207 lines changed)

@@ -1,5 +1,5 @@
+import argparse
 import numpy as np
-from scipy.ndimage import label, center_of_mass
 from PIL import Image, ImageDraw
 from tqdm import tqdm
@@ -8,17 +8,25 @@ import os
 from concurrent.futures import ProcessPoolExecutor
 from itertools import repeat
 import matplotlib.pyplot as plt
+import matplotlib
+from utils import visualize_samples
-RED = (0xCC, 0x24, 0x1D)
-GREEN = (0x98, 0x97, 0x1A)
-BLUE = (0x45, 0x85, 0x88)
-BACKGROUND = (0x50, 0x49, 0x45)
+from values import GREEN, BLUE, WHITE, BACKGROUND, EXPERIMENTS

+matplotlib.use("Agg")


 def create_sample_antialiased(
-    id: int, image_size: int, distance: int, radius: int, delta: int, scale=4
+    id: int,
+    image_size: int,
+    distance: int,
+    radius: int,
+    delta: int,
+    scale=4,
+    enable_constraint: bool = True,
+    enable_rectangle: bool = False,
+    rectangle_thickness: int = 0,
 ):
     # Scale up the image dimensions
     high_res_size = image_size * scale
@@ -37,9 +45,47 @@ def create_sample_antialiased(
         x1, y1 = np.random.randint(
             low=high_res_radius, high=high_res_size - high_res_radius, size=2
         )
+        if not enable_constraint:
+            break
         dist = np.sqrt((x0 - x1) ** 2 + (y0 - y1) ** 2)

-    # Draw anti-aliased circles using PIL's ellipse method
+    if enable_rectangle:
+        # Compute the vector from circle0 to circle1.
+        dx = x1 - x0
+        dy = y1 - y0
+        d = np.sqrt(dx**2 + dy**2)
+        if d != 0:
+            ux = dx / d
+            uy = dy / d
+        else:
+            ux, uy = 0, 0
+        # Extend endpoints to fully enclose both circles:
+        start_x = x0 - high_res_radius * ux
+        start_y = y0 - high_res_radius * uy
+        end_x = x1 + high_res_radius * ux
+        end_y = y1 + high_res_radius * uy
+        # Ensure the rectangle is thick enough to enclose the entire circles.
+        thickness = max(rectangle_thickness * scale, 2 * high_res_radius)
+        half_thickness = thickness / 2.0
+        # Compute perpendicular vector to (ux, uy)
+        perp_x = -uy
+        perp_y = ux
+        # Compute the four corners of the rectangle
+        p1 = (start_x + half_thickness * perp_x, start_y + half_thickness * perp_y)
+        p2 = (start_x - half_thickness * perp_x, start_y - half_thickness * perp_y)
+        p3 = (end_x - half_thickness * perp_x, end_y - half_thickness * perp_y)
+        p4 = (end_x + half_thickness * perp_x, end_y + half_thickness * perp_y)
+        # Draw the white rectangle (as a polygon)
+        draw.polygon([p1, p2, p3, p4], fill=WHITE)
+    # Draw anti-aliased circles using PIL's circle method (or ellipse if needed)
     draw.circle((x0, y0), high_res_radius, fill=GREEN)
     draw.circle((x1, y1), high_res_radius, fill=BLUE)
@@ -48,71 +94,16 @@ def create_sample_antialiased(
     return id, np.array(im)


-def create_sample(id: int, image_size: int, distance: int, radius: int, delta: int):
-    # Create a blank image
-    img = np.full(
-        shape=(image_size, image_size, 3), fill_value=BACKGROUND, dtype=np.uint8
-    )
-    # Compute random centers until they are inside the distance range
-    dist = float("inf")
-    while (dist < distance - delta) or (dist > distance + delta):
-        x0, y0 = np.random.randint(
-            low=radius, high=image_size - radius, size=2, dtype=np.int32
-        )
-        x1, y1 = np.random.randint(
-            low=radius, high=image_size - radius, size=2, dtype=np.int32
-        )
-        dist = np.sqrt((x0 - x1) ** 2 + (y0 - y1) ** 2)
-    # Draw the circles
-    xx, yy = np.mgrid[:image_size, :image_size]
-    # Create boolean masks for the circles based on the radius
-    mask0 = (xx - x0) ** 2 + (yy - y0) ** 2 <= radius**2
-    mask1 = (xx - x1) ** 2 + (yy - y1) ** 2 <= radius**2
-    # Apply the colors to the pixels where the mask is True
-    img[mask0] = GREEN
-    img[mask1] = BLUE
-    return id, img
-
-
-def detect_circle_centers(image, background=BACKGROUND, threshold=30):
-    """
-    Detects centers of circles in an image by finding connected components
-    that differ from the background color.
-
-    Args:
-        image (np.ndarray): The image array with shape (H, W, 3).
-        background (np.ndarray): The background color to ignore.
-        threshold (int): The minimum per-channel difference to consider a pixel
-            as part of a circle.
-
-    Returns:
-        centers (list of tuples): List of (row, col) coordinates for each detected circle.
-    """
-    # Compute the absolute difference from the background for each pixel.
-    diff = np.abs(image.astype(np.int16) - np.array(background).astype(np.int16))
-    # Create a mask where any channel difference exceeds the threshold.
-    mask = np.any(diff > threshold, axis=-1)
-    # Label connected regions in the mask.
-    labeled, num_features = label(mask)
-    centers = []
-    # Compute the center of mass for each labeled region.
-    for i in range(1, num_features + 1):
-        center = center_of_mass(mask, labeled, i)
-        centers.append(center)
-    return centers


-def generate_circle_dataset(num_samples, image_size, radius, distance, delta):
+def generate_circle_dataset(
+    num_samples,
+    image_size,
+    radius,
+    distance,
+    delta,
+    enable_constraint: bool = True,
+    enable_rectangle: bool = False,
+    rectangle_thickness: int = 0,
+):
     """
     Generate a dataset of images with two circles (green and blue) and save as numpy tensors.
@@ -132,57 +123,51 @@ def generate_circle_dataset(num_samples, image_size, radius, distance, delta):
             repeat(distance),
             repeat(radius),
             repeat(delta),
+            repeat(4),
+            repeat(enable_constraint),
+            repeat(enable_rectangle),
+            repeat(rectangle_thickness),
             chunksize=100,
         ):
             yield i, sample


-def visualize_samples(dataset: np.array, output_dir: str):
-    # Define the grid size (e.g., 5x5)
-    grid_size = 5
-    fig, axes = plt.subplots(grid_size, grid_size, figsize=(10, 10))
-    for i in range(grid_size):
-        for j in range(grid_size):
-            idx = i * grid_size + j
-            if idx < len(dataset):
-                img = dataset[idx]
-                axes[i, j].imshow(img)
-                centers = detect_circle_centers(img)
-                # Plot each detected center in red. Note that center_of_mass returns (row, col)
-                for center in centers:
-                    axes[i, j].scatter(center[1], center[0], c="red", s=20)
-            axes[i, j].axis("off")
-    plt.tight_layout()
-    plt.savefig(os.path.join(output_dir, "sample_grid.png"))
-    plt.close()
-
-
-if __name__ == "__main__":
+def main(args, experiment):
     # Create output directory if it doesn't exist
-    total_samples = 1_000_000
-    image_size = 32
-    distance = 11
-    delta = 4
-    radius = 3
-    output_dir = "data/circle_dataset"
-    os.makedirs(output_dir, exist_ok=True)
+    image_size = experiment["image_size"]
+    os.makedirs(args.output_dir, exist_ok=True)

-    dataset = np.empty((total_samples, image_size, image_size, 3), dtype=np.uint8)
+    dataset = np.empty((args.total_samples, image_size, image_size, 3), dtype=np.uint8)
     iterator = generate_circle_dataset(
         image_size=image_size,
-        num_samples=total_samples,
-        distance=distance,
-        delta=delta,
-        radius=radius,
+        num_samples=args.total_samples,
+        distance=experiment["distance"],
+        delta=experiment["delta"],
+        radius=experiment["radius"],
+        enable_constraint=True,
+        enable_rectangle=experiment["rectangle_thickness"] > 0,
+        rectangle_thickness=experiment["rectangle_thickness"],
     )
-    for i, sample in tqdm(iterator, total=total_samples):
+    for i, sample in tqdm(iterator, total=args.total_samples):
         dataset[i] = sample
-    visualize_samples(dataset, output_dir)
+    visualize_samples(
+        dataset, args.output_dir, f"sample_grid_{args.experiment_name}.png"
+    )

     # Save the dataset
-    np.save(os.path.join(output_dir, "data32.npy"), dataset)
+    np.save(os.path.join(args.output_dir, f"data-{args.experiment_name}.npy"), dataset)
     # np.savez_compressed(os.path.join(output_dir, "data.npy.npz"), dataset)


+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Generate a dataset of circles")
+    parser.add_argument(
+        "-e", "--experiment_name", required=True, type=str, choices=EXPERIMENTS.keys()
+    )
+    parser.add_argument("-o", "--output_dir", default="data/circle_dataset")
+    parser.add_argument("-t", "--total_samples", default=1_000_000, type=int)
+    args = parser.parse_args()
+    main(args, EXPERIMENTS[args.experiment_name])
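Usage sketch (not part of the commit; sample count illustrative, experiment keys from values.py). The new entry point can also be driven in-process:

import numpy as np
from values import EXPERIMENTS
from generate_circle_dataset import generate_circle_dataset

exp = EXPERIMENTS["rectangle"]
n = 1_000  # the CLI default is 1_000_000
dataset = np.empty((n, exp["image_size"], exp["image_size"], 3), dtype=np.uint8)
for i, sample in generate_circle_dataset(
    num_samples=n,
    image_size=exp["image_size"],
    radius=exp["radius"],
    distance=exp["distance"],
    delta=exp["delta"],
    enable_rectangle=exp["rectangle_thickness"] > 0,
    rectangle_thickness=exp["rectangle_thickness"],
):
    dataset[i] = sample

The equivalent command-line run would be: python generate_circle_dataset.py -e rectangle -t 1000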

src/ddpm/sample.py (350 lines changed)

@@ -1,288 +1,23 @@
-from typing import Dict
-from concurrent.futures import ProcessPoolExecutor
 import torch
 from tqdm import tqdm
 import matplotlib
 import matplotlib.pyplot as plt
-import seaborn as sns
 import numpy as np
-from train import extract, get_diffusion_params
-from train import TIMESTEPS, IMAGE_SIZE, DDIM_TIMESTEPS
+from train import get_diffusion_params
+from train import TIMESTEPS, IMAGE_SIZE
 from model import UNet
-from generate_circle_dataset import detect_circle_centers
+from generate_circle_dataset import (
+    visualize_samples,
+)
+from utils import compute_statistics, generate_synthetic_samples

 matplotlib.use("Agg")
 torch.manual_seed(1)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-GENERATE_SYNTHETIC_DATA = False
-@torch.compile
-@torch.no_grad()
-def ddpm_sample(
-    model: torch.nn.Module,
-    x: torch.Tensor,
-    t: torch.Tensor,
-    params: Dict[str, torch.Tensor],
-) -> torch.Tensor:
-    """Sample from the model at timestep t"""
-    predicted_noise = model(x, t)
-    one_over_alphas = extract(params["one_over_alphas"], t, x.shape)
-    posterior_mean_coef = extract(params["posterior_mean_coef"], t, x.shape)
-    pred_mean = one_over_alphas * (x - posterior_mean_coef * predicted_noise)
-    posterior_variance = extract(params["posterior_variance"], t, x.shape)
-    if t[0] > 0:
-        noise = torch.randn_like(x)
-        return pred_mean + torch.sqrt(posterior_variance) * noise
-    else:
-        return pred_mean
-
-
-@torch.no_grad()
-def ddim_sample(
-    model: torch.nn.Module,
-    x: torch.Tensor,
-    t: torch.Tensor,
-    params: Dict[str, torch.Tensor],
-) -> torch.Tensor:
-    """Sample from the model in a non-markovian way (DDIM)"""
-    stride = TIMESTEPS // DDIM_TIMESTEPS
-    t_prev = t - stride
-    predicted_noise = model(x, t)
-    alphas_prod = extract(params["alphas_cumprod"], t, x.shape)
-    valid_mask = (t_prev >= 0).view(-1, 1, 1, 1)
-    safe_t_prev = torch.maximum(t_prev, torch.tensor(0, device=device))
-    alphas_prod_prev = extract(params["alphas_cumprod"], safe_t_prev, x.shape)
-    alphas_prod_prev = torch.where(
-        valid_mask, alphas_prod_prev, torch.ones_like(alphas_prod_prev)
-    )
-    sigma = extract(params["ddim_sigma"], t, x.shape)
-    pred_x0 = (x - (1 - alphas_prod).sqrt() * predicted_noise) / alphas_prod.sqrt()
-    pred = (
-        alphas_prod_prev.sqrt() * pred_x0
-        + (1.0 - alphas_prod_prev).sqrt() * predicted_noise
-    )
-    if t[0] > 0:
-        noise = torch.randn_like(x)
-        pred = pred + noise * sigma
-    return pred
-
-
-@torch.no_grad()
-def ddpm_sample_images(
-    model: torch.nn.Module,
-    image_size: int,
-    batch_size: int,
-    channels: int,
-    device: torch.device,
-    params: Dict[str, torch.Tensor],
-):
-    """Generate new images using the trained model"""
-    x = torch.randn(batch_size, channels, image_size, image_size).to(device)
-    for t in tqdm(
-        reversed(range(TIMESTEPS)), desc="DDPM Sampling", total=TIMESTEPS, leave=False
-    ):
-        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
-        x = ddpm_sample(model, x, t_batch, params)
-        if x.isnan().any():
-            raise ValueError(f"NaN detected in image at timestep {t}")
-    return x
-
-
-def get_ddim_timesteps(
-    total_timesteps: int, num_sampling_timesteps: int
-) -> torch.Tensor:
-    """Gets the timesteps used for the DDIM process."""
-    assert total_timesteps % num_sampling_timesteps == 0
-    stride = total_timesteps // num_sampling_timesteps
-    timesteps = torch.arange(0, total_timesteps, stride)
-    return timesteps.flip(0)
-
-
-@torch.no_grad()
-def ddim_sample_images(
-    model: torch.nn.Module,
-    image_size: int,
-    batch_size: int,
-    channels: int,
-    device: torch.device,
-    params: Dict[str, torch.Tensor],
-):
-    """Generate new images using the trained model"""
-    x = torch.randn(batch_size, channels, image_size, image_size).to(device)
-    timesteps = get_ddim_timesteps(TIMESTEPS, DDIM_TIMESTEPS)
-    for i in tqdm(range(len(timesteps) - 1), desc="DDIM Sampling"):
-        t = torch.full((batch_size,), timesteps[i], device=device, dtype=torch.long)
-        x = ddim_sample(model, x, t, params)
-        if x.isnan().any():
-            raise ValueError(f"NaN detected at timestep {timesteps[i]}")
-    return x
-
-
-def show_images(images: torch.Tensor, title=""):
-    """Display a batch of images in a grid"""
-    for idx in range(min(16, len(images))):
-        plt.subplot(4, 4, idx + 1)
-        plt.imshow(images[idx])
-        # plt.imshow(np.transpose(images[idx], (1, 2, 0)))
-        plt.axis("off")
-    plt.suptitle(title)
-    plt.savefig("media/circles-predicted.png")
-    plt.close()
-
-
-def compute_statistics(data: np.array, generated: np.array):
-    data_centers = []
-    num_bad_samples = 0
-    # for centers in map(detect_circle_centers, data):
-    #     if len(centers) == 2:
-    #         data_centers.append(np.array(centers))
-    #     else:
-    #         num_bad_samples += 1
-    with ProcessPoolExecutor(max_workers=8) as executor:
-        for centers in executor.map(detect_circle_centers, data, chunksize=8):
-            if len(centers) == 2:
-                data_centers.append(np.array(centers))
-            else:
-                num_bad_samples += 1
-    if num_bad_samples > 0:
-        print("num bad samples in data: ", num_bad_samples)
-    data_centers = np.stack(data_centers, axis=0)  # (num_samples, 2, 2)
-    num_bad_samples = 0
-    generated_centers = []
-    with ProcessPoolExecutor(max_workers=16) as executor:
-        for centers in executor.map(detect_circle_centers, generated, chunksize=8):
-            if len(centers) == 2:
-                generated_centers.append(np.array(centers))
-            else:
-                num_bad_samples += 1
-    if num_bad_samples > 0:
-        print("num bad samples in generated: ", num_bad_samples)
-    generated_centers = np.stack(generated_centers, axis=0)  # (num_samples, 2, 2)
-    # Calculate distances from the center of the image
-    # image_center = IMAGE_SIZE / 2
-    # data_distances = np.sqrt(
-    #     (data_centers[:, 0] - image_center) ** 2
-    #     + (data_centers[:, 1] - image_center) ** 2
-    # )
-    # generated_distances = np.sqrt(
-    #     (generated_centers[:, 0] - image_center) ** 2
-    #     + (generated_centers[:, 1] - image_center) ** 2
-    # )
-    # Create a figure with subplots
-    plt.figure(figsize=(15, 10))
-    # Plot histogram of x positions
-    plt.subplot(2, 2, 1)
-    sns.histplot(
-        data_centers[:, :, 0].reshape(-1),
-        color="blue",
-        label="Data",
-        kde=True,
-        stat="density",
-    )
-    sns.histplot(
-        generated_centers[:, :, 0].reshape(-1),
-        color="orange",
-        label="Generated",
-        kde=True,
-        stat="density",
-    )
-    plt.title("X Position Distribution")
-    plt.xlabel("X Position")
-    plt.legend()
-    # Plot histogram of y positions
-    plt.subplot(2, 2, 2)
-    sns.histplot(
-        data_centers[:, :, 1].reshape(-1),
-        color="blue",
-        label="Data",
-        kde=True,
-        stat="density",
-    )
-    sns.histplot(
-        generated_centers[:, :, 1].reshape(-1),
-        color="orange",
-        label="Generated",
-        kde=True,
-        stat="density",
-    )
-    plt.title("Y Position Distribution")
-    plt.xlabel("Y Position")
-    plt.legend()
-    # Plot histogram of distances
-    plt.subplot(2, 2, 3)
-    distances = np.sqrt(
-        np.square(data_centers[:, ::2, 0] - data_centers[:, 1::2, 0])
-        + np.square(data_centers[:, ::2, 1] - data_centers[:, 1::2, 1])
-    ).squeeze()
-    generated_distances = np.sqrt(
-        np.square(generated_centers[:, ::2, 0] - generated_centers[:, 1::2, 0])
-        + np.square(generated_centers[:, ::2, 1] - generated_centers[:, 1::2, 1])
-    ).squeeze()
-    sns.histplot(distances, color="blue", label="Data", kde=True, stat="density")
-    sns.histplot(
-        generated_distances, color="orange", label="Generated", kde=True, stat="density"
-    )
-    plt.title("Distance between circles distribution")
-    plt.xlabel("Distance")
-    plt.legend()
-    # Plot 2D heatmap of center positions
-    plt.subplot(2, 2, 4)
-    sns.kdeplot(
-        x=data_centers[:, :, 0].reshape(-1),
-        y=data_centers[:, :, 1].reshape(-1),
-        cmap="Blues",
-        label="Data",
-    )
-    sns.kdeplot(
-        x=generated_centers[:, :, 0].reshape(-1),
-        y=generated_centers[:, :, 1].reshape(-1),
-        cmap="Oranges",
-        label="Generated",
-    )
-    plt.title("2D Heatmap of Center Positions")
-    plt.xlabel("X Position")
-    plt.ylabel("Y Position")
-    plt.xlim(0, IMAGE_SIZE)
-    plt.ylim(0, IMAGE_SIZE)
-    plt.legend()
-    plt.tight_layout()
-    plt.savefig("media/center_statistics.png")
-    print("Saved histograms at media/center_statistics.png")
-    plt.close()


+GENERATE_SYNTHETIC_DATA = True


 def load_dataset(data_path: str):
@@ -292,69 +27,34 @@ def load_dataset(data_path: str):
     return data


-def plot_bad_centers(generated: np.array):
-    generated_centers = []
-    num_bad_samples = 0
-    with ProcessPoolExecutor(max_workers=16) as executor:
-        for centers in executor.map(detect_circle_centers, generated, chunksize=8):
-            if len(centers) == 2:
-                generated_centers.append(np.array(centers))
-            else:
-                num_bad_samples += 1
-                generated_centers.append(np.zeros((2, 2)))
-    if num_bad_samples > 0:
-        print("num bad samples in generated: ", num_bad_samples)
-    generated_centers = np.stack(generated_centers, axis=0)  # (num_samples, 2, 2)
-    generated_distances = np.sqrt(
-        np.square(generated_centers[:, ::2, 0] - generated_centers[:, 1::2, 0])
-        + np.square(generated_centers[:, ::2, 1] - generated_centers[:, 1::2, 1])
-    ).squeeze()
-    mask = generated_distances > 18.0
-    generated = generated[mask]
-    show_images(generated)
 if __name__ == "__main__":
     torch.set_float32_matmul_precision("high")
     plt.figure(figsize=(10, 10))
-    data = load_dataset("./data/circle_dataset/data32.npy")
+    data = load_dataset("./data/circle_dataset/data-bigrectangle.npy")
     if GENERATE_SYNTHETIC_DATA:
-        nb_synthetic_samples = 10_000
+        nb_synthetic_samples = 102_400
         params = get_diffusion_params(TIMESTEPS, device, eta=0.0)
-        model = UNet(32, TIMESTEPS).to(device)
-        model.load_state_dict(torch.load("model.pkl", weights_only=True))
+        model = UNet(IMAGE_SIZE, TIMESTEPS).to(device)
+        model.load_state_dict(torch.load("model-bigrectangle.pkl", weights_only=True))
         model.eval()
-        generated = np.empty_like(data)
-        chunk = 500
-        samples = min(nb_synthetic_samples, data.shape[0])
-        for i in tqdm(range(samples // chunk), desc="Generating synthetic data."):
-            generated_images = ddpm_sample_images(
-                model=model,
-                image_size=IMAGE_SIZE,
-                batch_size=chunk,
-                channels=3,
-                device=device,
-                params=params,
-            )
-            generated_images = torch.permute(generated_images, (0, 2, 3, 1))
-            generated[i * chunk : (i + 1) * chunk] = (
-                (generated_images * 255.0)
-                .clip(min=0.0, max=255.0)
-                .cpu()
-                .numpy()
-                .astype(np.uint8)
-            )
-        np.save("./data/circle_dataset/generated32.npy", generated)
+        generated = generate_synthetic_samples(
+            model,
+            image_size=IMAGE_SIZE,
+            nb_synthetic_samples=nb_synthetic_samples,
+            params=params,
+            device=device,
+        )
+        np.save("./data/circle_dataset/generated-bigrectangle.npy", generated)
     else:
-        generated = np.load("./data/circle_dataset/generated32.npy")
+        generated = np.load("./data/circle_dataset/generated-bigrectangle.npy")
+    visualize_samples(generated, "")
     compute_statistics(data, generated)
-    plot_bad_centers(generated)
+    # plot_k_centers(generated, 1)
+    # plot_bad_centers(generated)
+    # plot_more_than_2_centers(generated)
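Both samplers now live in diffusion.py; this script is wired to DDPM sampling through generate_synthetic_samples. A minimal sketch of switching it to DDIM instead (same names as above; eta=0.0 makes ddim_sample deterministic):

from diffusion import ddim_sample_images

params = get_diffusion_params(TIMESTEPS, device, eta=0.0)
generated_images = ddim_sample_images(
    model,
    image_size=IMAGE_SIZE,
    batch_size=500,
    channels=3,
    device=device,
    params=params,
)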

src/ddpm/train.py (200 lines changed)

@@ -1,18 +1,30 @@
 from typing import Dict, Callable
-from itertools import islice
+import subprocess
+import argparse
 import numpy as np
 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
+import matplotlib
 import matplotlib.pyplot as plt
 from model import UNet
+from diffusion import get_diffusion_params, extract, TIMESTEPS
+from utils import generate_synthetic_samples, compute_statistics
+from values import EXPERIMENTS

+matplotlib.use("Agg")

 # Hyperparameters
-NUM_EPOCHS = 1
+NUM_EPOCHS = 10
 BATCH_SIZE = 512
 IMAGE_SIZE = 32
 CHANNELS = 3
-TIMESTEPS = 1000
-DDIM_TIMESTEPS = 500
+# Histogram generation
+HIST_GENERATED_SAMPLES = 10_000
+HIST_DATASET_SAMPLES = (
+    100_000  # Using the whole dataset becomes super slow, so we take a subset
+)


 def load_dataset(data_path: str):
@@ -24,12 +36,13 @@ def load_dataset(data_path: str):
     return data


-def create_dataset_loader(data: np.array):
-    nb_batches = data.shape[0] // BATCH_SIZE
+def create_dataset_loader(data: np.array, batch_size: int):
+    nb_batches = data.shape[0] // batch_size
     ids = np.arange(data.shape[0])
     np.random.shuffle(ids)
     for i in range(nb_batches):
-        batch_ids = ids[i * BATCH_SIZE : (i + 1) * BATCH_SIZE]
+        batch_ids = ids[i * batch_size : (i + 1) * batch_size]
         yield data[batch_ids]
@@ -37,75 +50,9 @@ def count_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)


-def get_diffusion_params(
-    timesteps: int,
-    device: torch.device,
-    ddim_timesteps: int = DDIM_TIMESTEPS,
-    eta=0.0,
-) -> Dict[str, torch.Tensor]:
-    def linear_beta_schedule(timesteps):
-        beta_start = 0.0001
-        beta_end = 0.02
-        return torch.linspace(beta_start, beta_end, timesteps)
-
-    betas = linear_beta_schedule(timesteps)
-    alphas = 1.0 - betas
-    alphas_cumprod = torch.cumprod(alphas, dim=0)
-    alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
-    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
-    one_over_alphas = 1.0 / torch.sqrt(alphas)
-    posterior_mean_coef = betas / sqrt_one_minus_alphas_cumprod
-    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
-    ddim_sigma = eta * torch.sqrt(
-        (1.0 - alphas_cumprod_prev)
-        / (1.0 - alphas_cumprod)
-        * (1 - alphas_cumprod / alphas_cumprod_prev)
-    )
-    return {
-        # DDPM Parameters
-        "betas": betas.to(device),
-        "alphas_cumprod": alphas_cumprod.to(device),
-        "posterior_variance": posterior_variance.to(device),
-        "one_over_alphas": one_over_alphas.to(device),
-        "posterior_mean_coef": posterior_mean_coef.to(device),
-        # DDIM Parameters
-        "ddim_sigma": ddim_sigma.to(device),
-    }
-
-
-def extract(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Tensor.shape):
-    """Extract coefficients at specified timesteps t"""
-    batch_size = t.shape[0]
-    out = a.gather(-1, t)
-    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
-
-
-def get_lr(
-    it: int,
-    warmup_iters: int = 80,
-    lr_decay_iters: int = 900,
-    min_lr: float = 3e-5,
-    learning_rate: float = 1e-4,
-):
-    # 1) linear warmup for warmup_iters steps
-    if it < warmup_iters:
-        return learning_rate * (it + 1) / (warmup_iters + 1)
-    # 2) if it > lr_decay_iters, return min learning rate
-    if it > lr_decay_iters:
-        return min_lr
-    # 3) in between, use cosine decay down to min learning rate
-    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
-    assert 0 <= decay_ratio <= 1
-    coeff = 0.5 * (1.0 + np.cos(np.pi * decay_ratio))  # coeff ranges 0..1
-    return min_lr + coeff * (learning_rate - min_lr)


 def get_loss_fn(model: torch.nn.Module, params: Dict[str, torch.Tensor]) -> Callable:
     device = next(model.parameters()).device

     @torch.compile
     def loss_fn(x_0):
         batch_size = x_0.shape[0]
@@ -126,39 +73,37 @@ def get_loss_fn(model: torch.nn.Module, params: Dict[str, torch.Tensor]) -> Callable:
 def train_epoch(
-    model: torch.nn.Module, optimize, train_loader: DataLoader, loss_fn: Callable
+    model: torch.nn.Module, optimizer, train_loader: DataLoader, loss_fn: Callable
 ) -> float:
-    model.train()
-    total_loss = 0
     device = next(model.parameters()).device
-    with tqdm(islice(train_loader, 200), leave=False) as pbar:
-        steps = 0
+    model.train()
+    train_losses = []
+    with tqdm(train_loader, leave=False) as pbar:
         for batch in pbar:
             # lr = get_lr(steps)
             # for param_group in optimizer.param_groups:
             #     param_group["lr"] = lr
             images = torch.tensor(batch, device=device)
             optimizer.zero_grad()
             loss = loss_fn(images)
             loss.backward()
             optimizer.step()
+            train_losses.append(loss.item())
-            total_loss += loss.item()
-            steps += 1
             pbar.set_description(f"Loss: {loss.item():.4f}")
-    return total_loss / steps
+    return train_losses
-if __name__ == "__main__":
+def main(args):
     torch.set_float32_matmul_precision("high")
+    image_size = EXPERIMENTS[args.experiment_name]["image_size"]
     # Set random seed for reproducibility
     torch.manual_seed(42)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model = UNet(32, TIMESTEPS).to(device)
+    model = UNet(image_size, TIMESTEPS).to(device)
     nb_params = count_parameters(model)
     print(f"Total number of parameters: {nb_params}")
@@ -166,12 +111,79 @@ if __name__ == "__main__":
     params = get_diffusion_params(TIMESTEPS, device)
     loss_fn = get_loss_fn(model, params)

-    data = load_dataset("./data/circle_dataset/data32.npy")
+    data = load_dataset(f"./data/circle_dataset/data-{args.experiment_name}.npy")
+    if args.histogram:
+        ids = np.random.choice(
+            np.arange(data.shape[0]), size=HIST_DATASET_SAMPLES, replace=False
+        )
+        data_hist_set = data[ids]

     # Main training loop
-    for e in tqdm(range(NUM_EPOCHS)):
-        train_loader = create_dataset_loader(data)
-        train_epoch(model, optimizer, train_loader, loss_fn)
+    train_losses = []
+    for e in tqdm(range(args.epochs)):
+        train_loader = create_dataset_loader(data, BATCH_SIZE)
+        epoch_losses = train_epoch(model, optimizer, train_loader, loss_fn)
+        train_losses.extend(epoch_losses)
+        if args.histogram:
+            generated = generate_synthetic_samples(
+                model,
+                image_size=image_size,
+                nb_synthetic_samples=HIST_GENERATED_SAMPLES,
+                params=params,
+                device=device,
+            )
+            compute_statistics(
+                data_hist_set,
+                generated,
+                output_path=f"center_statistics_{args.experiment_name}_{e}.png",
+            )

     # Save model after training
-    torch.save(model.state_dict(), "model.pkl")
+    torch.save(model.state_dict(), f"model-{args.experiment_name}.pkl")

+    # Plot training loss curve
+    plt.figure(figsize=(10, 5))
+    plt.plot(train_losses, label="Training Loss")
+    plt.yscale("log")
+    plt.xlabel("Batch")
+    plt.ylabel("Loss")
+    plt.title("Training Loss Curve")
+    plt.legend()
+    plt.grid(True)
+    plt.savefig(f"training_loss_{args.experiment_name}.png")
+    plt.close()

+    # Generate animation of the statistics during training
+    if args.histogram:
+        # Call ffmpeg to create a video from the images
+        subprocess.run(
+            [
+                "ffmpeg",
+                "-framerate",
+                "1",
+                "-i",
+                f"center_statistics_{args.experiment_name}_%d.png",
+                "-c:v",
+                "libx264",
+                "-r",
+                "30",
+                "-pix_fmt",
+                "yuv420p",
+                f"center_statistics_{args.experiment_name}.mp4",
+            ]
+        )


+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Train a DDPM model.")
+    parser.add_argument("-H", "--histogram", action="store_true")
+    parser.add_argument(
+        "-e",
+        "--experiment_name",
+        default="vanilla",
+        type=str,
+        choices=EXPERIMENTS.keys(),
+    )
+    parser.add_argument("--epochs", default=10, type=int)
+    args = parser.parse_args()
+    main(args)
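For example (paths as wired above), python train.py -e bigrectangle -H --epochs 10 should train on data/circle_dataset/data-bigrectangle.npy, save model-bigrectangle.pkl and training_loss_bigrectangle.png, and, with -H, render per-epoch center statistics that ffmpeg stitches into center_statistics_bigrectangle.mp4.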

src/ddpm/utils.py (329 lines added)

@@ -0,0 +1,329 @@
import os
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat

import torch
import numpy as np
from scipy.ndimage import label, center_of_mass
from tqdm import tqdm
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from values import GREEN, BLUE
from diffusion import ddpm_sample_images

matplotlib.use("Agg")
def detect_circle_centers(image, circle_colors, threshold=30, min_pixels=12):
    """
    Detects centers of circles in an image based on their known colors, filtering out
    regions that have less than a specified number of pixels.

    This function creates a mask for each provided circle color by selecting pixels
    whose RGB values are close to the target color (within a per-channel threshold).
    It then labels the connected regions in the mask, filters out small regions based
    on the min_pixels parameter, and computes the centers of the remaining regions.

    Args:
        image (np.ndarray): The image array with shape (H, W, 3).
        circle_colors (list of tuple): List of RGB tuples for the circle colors to
            detect, e.g. [GREEN, BLUE].
        threshold (int): Maximum allowed difference per channel between a pixel and
            the target circle color.
        min_pixels (int): Minimum number of pixels for a region to be considered valid.

    Returns:
        centers (list of tuples): List of (row, col) coordinates for each detected circle.
        image (np.ndarray): The original image.
    """
    centers = []
    # Loop over each target circle color.
    for color in circle_colors:
        # Compute absolute difference between each pixel and the target color.
        diff = np.abs(image.astype(np.int16) - np.array(color, dtype=np.int16))
        # Create a mask: pixels where all channels are within the threshold.
        mask = np.all(diff < threshold, axis=-1)
        # Label connected regions in the mask.
        labeled, num_features = label(mask)
        # Process each labeled region.
        for i in range(1, num_features + 1):
            # Count the number of pixels in the current region.
            region_size = np.sum(labeled == i)
            # Skip regions that are smaller than the minimum required.
            if region_size < min_pixels:
                continue
            center = center_of_mass(mask, labeled, i)
            centers.append(center)
    return centers, image
def compute_statistics(
    data: np.array,
    generated: np.array,
    output_path: str = "center_statistics.png",
):
    assert len(data.shape) == 4
    assert data.shape[2] == data.shape[3]
    image_size = data.shape[2]

    data_centers = []
    num_bad_samples = 0
    with ProcessPoolExecutor(max_workers=8) as executor:
        for centers, _ in executor.map(
            detect_circle_centers, data, repeat((BLUE, GREEN)), chunksize=8
        ):
            if len(centers) == 2:
                data_centers.append(np.array(centers))
            else:
                num_bad_samples += 1
    if num_bad_samples > 0:
        print("num bad samples in data: ", num_bad_samples)
    data_centers = np.stack(data_centers, axis=0)  # (num_samples, 2, 2)

    num_bad_samples = 0
    generated_centers = []
    with ProcessPoolExecutor(max_workers=16) as executor:
        for centers, _ in executor.map(
            detect_circle_centers, generated, repeat((BLUE, GREEN)), chunksize=8
        ):
            if len(centers) == 2:
                generated_centers.append(np.array(centers))
            else:
                num_bad_samples += 1
    if num_bad_samples > 0:
        print("num bad samples in generated: ", num_bad_samples)
    generated_centers = np.stack(generated_centers, axis=0)  # (num_samples, 2, 2)

    # Create a figure with subplots
    plt.figure(figsize=(15, 10))

    # Plot histogram of x positions
    plt.subplot(2, 2, 1)
    sns.histplot(
        data_centers[:, :, 0].reshape(-1),
        color="blue",
        label="Data",
        kde=True,
        stat="density",
    )
    sns.histplot(
        generated_centers[:, :, 0].reshape(-1),
        color="orange",
        label="Generated",
        kde=True,
        stat="density",
    )
    plt.title("X Position Distribution")
    plt.xlabel("X Position")
    plt.legend()

    # Plot histogram of y positions
    plt.subplot(2, 2, 2)
    sns.histplot(
        data_centers[:, :, 1].reshape(-1),
        color="blue",
        label="Data",
        kde=True,
        stat="density",
    )
    sns.histplot(
        generated_centers[:, :, 1].reshape(-1),
        color="orange",
        label="Generated",
        kde=True,
        stat="density",
    )
    plt.title("Y Position Distribution")
    plt.xlabel("Y Position")
    plt.legend()

    # Plot histogram of distances
    plt.subplot(2, 2, 3)
    distances = np.sqrt(
        np.square(data_centers[:, ::2, 0] - data_centers[:, 1::2, 0])
        + np.square(data_centers[:, ::2, 1] - data_centers[:, 1::2, 1])
    ).squeeze()
    generated_distances = np.sqrt(
        np.square(generated_centers[:, ::2, 0] - generated_centers[:, 1::2, 0])
        + np.square(generated_centers[:, ::2, 1] - generated_centers[:, 1::2, 1])
    ).squeeze()
    sns.histplot(distances, color="blue", label="Data", kde=True, stat="density")
    sns.histplot(
        generated_distances, color="orange", label="Generated", kde=True, stat="density"
    )
    plt.title("Distance between circles distribution")
    plt.xlabel("Distance")
    plt.legend()

    # Plot 2D heatmap of center positions
    plt.subplot(2, 2, 4)
    sns.kdeplot(
        x=data_centers[:, :, 0].reshape(-1),
        y=data_centers[:, :, 1].reshape(-1),
        cmap="Blues",
        label="Data",
    )
    sns.kdeplot(
        x=generated_centers[:, :, 0].reshape(-1),
        y=generated_centers[:, :, 1].reshape(-1),
        cmap="Oranges",
        label="Generated",
    )
    plt.title("2D Heatmap of Center Positions")
    plt.xlabel("X Position")
    plt.ylabel("Y Position")
    plt.xlim(0, image_size)
    plt.ylim(0, image_size)
    plt.legend()

    plt.tight_layout()
    output_path = os.path.join("media", output_path)
    plt.savefig(output_path)
    print(f"Saved histograms at {output_path}")
    plt.close()
def plot_k_centers(dataset: np.array, k: int = 1):
    """From a given dataset, plot some samples that have k centers in them."""
    images = []
    with ProcessPoolExecutor(max_workers=8) as executor:
        for centers, image in executor.map(
            detect_circle_centers, dataset, repeat((BLUE, GREEN)), chunksize=8
        ):
            if len(centers) == k:
                images.append(image)
    print("num samples: ", len(images))
    images = np.stack(images, axis=0)
    show_images(images, f"{k} circles", grid_size=8)
def plot_bad_centers(dataset: np.array, threshold: float = 18.0):
    """From a given dataset, plot some samples that violate the distance constraint."""
    generated_centers = []
    num_bad_samples = 0
    with ProcessPoolExecutor(max_workers=8) as executor:
        for centers, _ in executor.map(
            detect_circle_centers, dataset, repeat((BLUE, GREEN)), chunksize=8
        ):
            if len(centers) == 2:
                generated_centers.append(np.array(centers))
            else:
                generated_centers.append(np.zeros((2, 2)))
                num_bad_samples += 1
    if num_bad_samples > 0:
        print("num bad samples in generated: ", num_bad_samples)
    generated_centers = np.stack(generated_centers, axis=0)  # (num_samples, 2, 2)
    generated_distances = np.sqrt(
        np.square(generated_centers[:, ::2, 0] - generated_centers[:, 1::2, 0])
        + np.square(generated_centers[:, ::2, 1] - generated_centers[:, 1::2, 1])
    ).squeeze()
    mask = generated_distances > threshold
    generated = dataset[mask]
    show_images(generated)
def plot_more_than_2_centers(generated: np.array):
    nb_samples = 0
    images = []
    with ProcessPoolExecutor(max_workers=16) as executor:
        for centers, image in executor.map(
            detect_circle_centers, generated, repeat((BLUE, GREEN)), chunksize=8
        ):
            # Keep samples with more than two detected centers, per the function name.
            if len(centers) > 2:
                images.append(image)
                nb_samples += 1
    print("Nb images multiple centers: ", nb_samples)
    show_images(images)
def show_images(images, title="", grid_size: int = 4):
    """Display a batch of images in a grid"""
    for idx in range(min(grid_size**2, len(images))):
        ax = plt.subplot(grid_size, grid_size, idx + 1)
        ax.imshow(images[idx])
        centers, _ = detect_circle_centers(images[idx], (BLUE, GREEN))
        for center in centers:
            ax.scatter(center[1], center[0], c="red", s=20)
        ax.axis("off")
    plt.suptitle(title)
    plt.savefig("media/circles-predicted.png")
    plt.close()
def visualize_samples(
    dataset: np.array, output_dir: str, filename: str = "sample_grid.png"
):
    """From a given dataset, visualize some samples"""
    # Define the grid size (e.g., 5x5)
    grid_size = 5
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(10, 10))
    for i in range(grid_size):
        for j in range(grid_size):
            idx = i * grid_size + j
            if idx < len(dataset):
                img = dataset[idx]
                axes[i, j].imshow(img)
                centers, _ = detect_circle_centers(img, (BLUE, GREEN))
                # Plot each detected center in red. Note that center_of_mass
                # returns (row, col).
                for center in centers:
                    axes[i, j].scatter(center[1], center[0], c="red", s=20)
            axes[i, j].axis("off")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, filename))
    plt.close()
def generate_synthetic_samples(
    model,
    image_size: int,
    nb_synthetic_samples: int,
    params,
    device="cuda",
):
    model.eval()
    generated = np.empty(
        shape=(nb_synthetic_samples, image_size, image_size, 3), dtype=np.uint8
    )
    chunk = 128
    for i in tqdm(
        range(nb_synthetic_samples // chunk), desc="Generating synthetic data."
    ):
        generated_images = ddpm_sample_images(
            model=model,
            image_size=image_size,
            batch_size=chunk,
            channels=3,
            device=device,
            params=params,
        )
        generated_images = torch.permute(generated_images, (0, 2, 3, 1))
        generated[i * chunk : (i + 1) * chunk] = (
            (generated_images * 255.0)
            .clip(min=0.0, max=255.0)
            .cpu()
            .numpy()
            .astype(np.uint8)
        )
    return generated
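A small sketch of the detection helper on a single dataset image (hypothetical path; colors from values.py):

import numpy as np
from values import BLUE, GREEN
from utils import detect_circle_centers

data = np.load("data/circle_dataset/data-vanilla.npy")  # hypothetical path
centers, _ = detect_circle_centers(data[0], (BLUE, GREEN))
print(centers)  # e.g. [(row_blue, col_blue), (row_green, col_green)] when both circles are found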

src/ddpm/values.py (29 lines added)

@@ -0,0 +1,29 @@
RED = (0xCC, 0x24, 0x1D)
GREEN = (0x98, 0x97, 0x1A)
BLUE = (0x45, 0x85, 0x88)
WHITE = (0xFB, 0xF1, 0xC7)
BACKGROUND = (0x50, 0x49, 0x45)

EXPERIMENTS = {
    "vanilla": {
        "image_size": 32,
        "distance": 11,
        "delta": 4,
        "radius": 3,
        "rectangle_thickness": 0,
    },
    "rectangle": {
        "image_size": 32,
        "distance": 11,
        "delta": 4,
        "radius": 3,
        "rectangle_thickness": 4,
    },
    "bigrectangle": {
        "image_size": 32,
        "distance": 11,
        "delta": 4,
        "radius": 3,
        "rectangle_thickness": 8,
    },
}
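Each key doubles as the -e/--experiment_name choice in train.py and generate_circle_dataset.py, so adding a configuration is a single entry (illustrative sketch, not part of the commit):

EXPERIMENTS["hugerectangle"] = {
    "image_size": 32,
    "distance": 11,
    "delta": 4,
    "radius": 3,
    "rectangle_thickness": 16,
}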