diff --git a/src/ddpm/diffusion.py b/src/ddpm/diffusion.py
new file mode 100644
index 0000000..36d31a9
--- /dev/null
+++ b/src/ddpm/diffusion.py
@@ -0,0 +1,176 @@
+from typing import Dict
+import torch
+from tqdm import tqdm
+
+TIMESTEPS = 1000
+DDIM_TIMESTEPS = 500
+
+
+@torch.compile
+@torch.no_grad()
+def ddpm_sample(
+    model: torch.nn.Module,
+    x: torch.Tensor,
+    t: torch.Tensor,
+    params: Dict[str, torch.Tensor],
+) -> torch.Tensor:
+    """Sample x_{t-1} from the model at timestep t."""
+    predicted_noise = model(x, t)
+
+    one_over_alphas = extract(params["one_over_alphas"], t, x.shape)
+    posterior_mean_coef = extract(params["posterior_mean_coef"], t, x.shape)
+
+    pred_mean = one_over_alphas * (x - posterior_mean_coef * predicted_noise)
+
+    posterior_variance = extract(params["posterior_variance"], t, x.shape)
+
+    if t[0] > 0:
+        noise = torch.randn_like(x)
+        return pred_mean + torch.sqrt(posterior_variance) * noise
+    else:
+        return pred_mean
+
+
+@torch.no_grad()
+def ddim_sample(
+    model: torch.nn.Module,
+    x: torch.Tensor,
+    t: torch.Tensor,
+    params: Dict[str, torch.Tensor],
+) -> torch.Tensor:
+    """Sample from the model in a non-Markovian way (DDIM)."""
+    # Derive the device from the model; callers don't need to pass it
+    # (ddim_sample_images below calls this with four arguments).
+    device = next(model.parameters()).device
+
+    stride = TIMESTEPS // DDIM_TIMESTEPS
+    t_prev = t - stride
+    predicted_noise = model(x, t)
+
+    alphas_prod = extract(params["alphas_cumprod"], t, x.shape)
+    valid_mask = (t_prev >= 0).view(-1, 1, 1, 1)
+    safe_t_prev = torch.maximum(t_prev, torch.tensor(0, device=device))
+    alphas_prod_prev = extract(params["alphas_cumprod"], safe_t_prev, x.shape)
+    alphas_prod_prev = torch.where(
+        valid_mask, alphas_prod_prev, torch.ones_like(alphas_prod_prev)
+    )
+
+    sigma = extract(params["ddim_sigma"], t, x.shape)
+
+    pred_x0 = (x - (1 - alphas_prod).sqrt() * predicted_noise) / alphas_prod.sqrt()
+
+    pred = (
+        alphas_prod_prev.sqrt() * pred_x0
+        + (1.0 - alphas_prod_prev).sqrt() * predicted_noise
+    )
+
+    if t[0] > 0:
+        noise = torch.randn_like(x)
+        pred = pred + noise * sigma
+
+    return pred
+
+
+@torch.no_grad()
+def ddpm_sample_images(
+    model: torch.nn.Module,
+    image_size: int,
+    batch_size: int,
+    channels: int,
+    device: torch.device,
+    params: Dict[str, torch.Tensor],
+):
+    """Generate new images using the trained model."""
+    x = torch.randn(batch_size, channels, image_size, image_size).to(device)
+
+    for t in tqdm(
+        reversed(range(TIMESTEPS)), desc="DDPM Sampling", total=TIMESTEPS, leave=False
+    ):
+        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
+        x = ddpm_sample(model, x, t_batch, params)
+
+        if x.isnan().any():
+            raise ValueError(f"NaN detected in image at timestep {t}")
+
+    return x
+
+
+def get_ddim_timesteps(
+    total_timesteps: int, num_sampling_timesteps: int
+) -> torch.Tensor:
+    """Gets the timesteps used for the DDIM process, ordered high to low."""
+    assert total_timesteps % num_sampling_timesteps == 0
+    stride = total_timesteps // num_sampling_timesteps
+    timesteps = torch.arange(0, total_timesteps, stride)
+    return timesteps.flip(0)
+
+
+@torch.no_grad()
+def ddim_sample_images(
+    model: torch.nn.Module,
+    image_size: int,
+    batch_size: int,
+    channels: int,
+    device: torch.device,
+    params: Dict[str, torch.Tensor],
+):
+    """Generate new images using the trained model."""
+    x = torch.randn(batch_size, channels, image_size, image_size).to(device)
+
+    timesteps = get_ddim_timesteps(TIMESTEPS, DDIM_TIMESTEPS)
+
+    for i in tqdm(range(len(timesteps) - 1), desc="DDIM Sampling"):
+        t = torch.full((batch_size,), timesteps[i], device=device, dtype=torch.long)
+        x = ddim_sample(model, x, t, params)
+
+        if x.isnan().any():
+            raise ValueError(f"NaN detected at timestep {timesteps[i]}")
+
+    return x
+
+
+def extract(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor:
+    """Extract coefficients at specified timesteps t, reshaped to broadcast over x."""
+    batch_size = t.shape[0]
+    out = a.gather(-1, t)
+    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
+
+
+def get_diffusion_params(
+    timesteps: int,
+    device: torch.device,
+    eta=0.0,
+) -> Dict[str, torch.Tensor]:
+    def linear_beta_schedule(timesteps):
+        beta_start = 0.0001
+        beta_end = 0.02
+        return torch.linspace(beta_start, beta_end, timesteps)
+
+    betas = linear_beta_schedule(timesteps)
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
+
+    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
+
+    # Note: despite the name, this is 1/sqrt(alpha_t), the DDPM posterior-mean scale.
+    one_over_alphas = 1.0 / torch.sqrt(alphas)
+    posterior_mean_coef = betas / sqrt_one_minus_alphas_cumprod
+
+    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
+
+    # Computed with consecutive-step alphas; under strided DDIM sampling this is
+    # only exact for eta == 0, which is the only value used in this repo.
+    ddim_sigma = eta * torch.sqrt(
+        (1.0 - alphas_cumprod_prev)
+        / (1.0 - alphas_cumprod)
+        * (1 - alphas_cumprod / alphas_cumprod_prev)
+    )
+
+    return {
+        # DDPM Parameters
+        "betas": betas.to(device),
+        "alphas_cumprod": alphas_cumprod.to(device),
+        "posterior_variance": posterior_variance.to(device),
+        "one_over_alphas": one_over_alphas.to(device),
+        "posterior_mean_coef": posterior_mean_coef.to(device),
+        # DDIM Parameters
+        "ddim_sigma": ddim_sigma.to(device),
+    }
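For reference, these are the standard updates the two samplers implement, written in the notation of `get_diffusion_params` (DDPM from Ho et al. 2020, DDIM from Song et al. 2020). One thing worth noting: the direction term in `ddim_sample` uses sqrt(1 - abar_prev) rather than sqrt(1 - abar_prev - sigma_t^2); the two coincide only when eta = 0, which is the only setting used here.

```latex
% DDPM ancestral step (ddpm_sample), with
%   one_over_alphas[t]     = 1/\sqrt{\alpha_t}
%   posterior_mean_coef[t] = \beta_t / \sqrt{1-\bar\alpha_t}
%   posterior_variance[t]  = \beta_t (1-\bar\alpha_{t-1}) / (1-\bar\alpha_t)
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}
          \left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t,t)\right)
        + \sqrt{\tilde\beta_t}\, z, \qquad z \sim \mathcal{N}(0, I)

% DDIM step (ddim_sample), jumping from t to t' = t - stride:
\hat x_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t,t)}{\sqrt{\bar\alpha_t}},
\qquad
x_{t'} = \sqrt{\bar\alpha_{t'}}\,\hat x_0
       + \sqrt{1-\bar\alpha_{t'}-\sigma_t^2}\,\epsilon_\theta(x_t,t)
       + \sigma_t z
```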
diff --git a/src/ddpm/generate_circle_dataset.py b/src/ddpm/generate_circle_dataset.py
index 44f6366..fdea3f3 100644
--- a/src/ddpm/generate_circle_dataset.py
+++ b/src/ddpm/generate_circle_dataset.py
@@ -1,5 +1,5 @@
+import argparse
 import numpy as np
-from scipy.ndimage import label, center_of_mass
 from PIL import Image, ImageDraw
 from tqdm import tqdm
@@ -8,17 +8,25 @@
 import os
 from concurrent.futures import ProcessPoolExecutor
 from itertools import repeat
 
-import matplotlib.pyplot as plt
+import matplotlib
+from utils import visualize_samples
 
-RED = (0xCC, 0x24, 0x1D)
-GREEN = (0x98, 0x97, 0x1A)
-BLUE = (0x45, 0x85, 0x88)
-BACKGROUND = (0x50, 0x49, 0x45)
+
+from values import GREEN, BLUE, WHITE, BACKGROUND, EXPERIMENTS
+
+matplotlib.use("Agg")
 
 
 def create_sample_antialiased(
-    id: int, image_size: int, distance: int, radius: int, delta: int, scale=4
+    id: int,
+    image_size: int,
+    distance: int,
+    radius: int,
+    delta: int,
+    scale=4,
+    enable_constraint: bool = True,
+    enable_rectangle: bool = False,
+    rectangle_thickness: int = 0,
 ):
     # Scale up the image dimensions
     high_res_size = image_size * scale
@@ -37,9 +45,47 @@ def create_sample_antialiased(
         x1, y1 = np.random.randint(
             low=high_res_radius, high=high_res_size - high_res_radius, size=2
         )
+
+        if not enable_constraint:
+            break
+
         dist = np.sqrt((x0 - x1) ** 2 + (y0 - y1) ** 2)
 
-    # Draw anti-aliased circles using PIL's ellipse method
+    if enable_rectangle:
+        # Compute the vector from circle0 to circle1.
+        dx = x1 - x0
+        dy = y1 - y0
+        d = np.sqrt(dx**2 + dy**2)
+        if d != 0:
+            ux = dx / d
+            uy = dy / d
+        else:
+            ux, uy = 0, 0
+
+        # Extend endpoints to fully enclose both circles:
+        start_x = x0 - high_res_radius * ux
+        start_y = y0 - high_res_radius * uy
+        end_x = x1 + high_res_radius * ux
+        end_y = y1 + high_res_radius * uy
+
+        # Ensure the rectangle is thick enough to enclose the entire circles.
+        thickness = max(rectangle_thickness * scale, 2 * high_res_radius)
+        half_thickness = thickness / 2.0
+
+        # Compute the perpendicular vector to (ux, uy).
+        perp_x = -uy
+        perp_y = ux
+
+        # Compute the four corners of the rectangle.
+        p1 = (start_x + half_thickness * perp_x, start_y + half_thickness * perp_y)
+        p2 = (start_x - half_thickness * perp_x, start_y - half_thickness * perp_y)
+        p3 = (end_x - half_thickness * perp_x, end_y - half_thickness * perp_y)
+        p4 = (end_x + half_thickness * perp_x, end_y + half_thickness * perp_y)
+
+        # Draw the white rectangle (as a polygon).
+        draw.polygon([p1, p2, p3, p4], fill=WHITE)
+
+    # Draw anti-aliased circles using PIL's circle method (added in Pillow 10.4).
     draw.circle((x0, y0), high_res_radius, fill=GREEN)
     draw.circle((x1, y1), high_res_radius, fill=BLUE)
@@ -48,71 +94,16 @@ def create_sample_antialiased(
     return id, np.array(im)
 
 
-def create_sample(id: int, image_size: int, distance: int, radius: int, delta: int):
-    # Create a blank image
-    img = np.full(
-        shape=(image_size, image_size, 3), fill_value=BACKGROUND, dtype=np.uint8
-    )
-
-    # Compute random centers until they are inside the distance range
-    dist = float("inf")
-    while (dist < distance - delta) or (dist > distance + delta):
-        x0, y0 = np.random.randint(
-            low=radius, high=image_size - radius, size=2, dtype=np.int32
-        )
-        x1, y1 = np.random.randint(
-            low=radius, high=image_size - radius, size=2, dtype=np.int32
-        )
-
-        dist = np.sqrt((x0 - x1) ** 2 + (y0 - y1) ** 2)
-
-    # Draw the circles
-
-    xx, yy = np.mgrid[:image_size, :image_size]
-
-    # Create boolean masks for the circles based on the radius
-    mask0 = (xx - x0) ** 2 + (yy - y0) ** 2 <= radius**2
-    mask1 = (xx - x1) ** 2 + (yy - y1) ** 2 <= radius**2
-
-    # Apply the colors to the pixels where the mask is True
-    img[mask0] = GREEN
-    img[mask1] = BLUE
-
-    return id, img
-
-
-def detect_circle_centers(image, background=BACKGROUND, threshold=30):
-    """
-    Detects centers of circles in an image by finding connected components
-    that differ from the background color.
-
-    Args:
-        image (np.ndarray): The image array with shape (H, W, 3).
-        background (np.ndarray): The background color to ignore.
-        threshold (int): The minimum per-channel difference to consider a pixel
-            as part of a circle.
-
-    Returns:
-        centers (list of tuples): List of (row, col) coordinates for each detected circle.
-    """
-    # Compute the absolute difference from the background for each pixel.
-    diff = np.abs(image.astype(np.int16) - np.array(background).astype(np.int16))
-    # Create a mask where any channel difference exceeds the threshold.
-    mask = np.any(diff > threshold, axis=-1)
-
-    # Label connected regions in the mask.
-    labeled, num_features = label(mask)
-
-    centers = []
-    # Compute the center of mass for each labeled region.
-    for i in range(1, num_features + 1):
-        center = center_of_mass(mask, labeled, i)
-        centers.append(center)
-
-    return centers
-
-
-def generate_circle_dataset(num_samples, image_size, radius, distance, delta):
+def generate_circle_dataset(
+    num_samples,
+    image_size,
+    radius,
+    distance,
+    delta,
+    enable_constraint: bool = True,
+    enable_rectangle: bool = False,
+    rectangle_thickness: int = 0,
+):
     """
-    Generate a dataset of images with two circles (red and blue) and save as numpy tensors.
+    Generate a dataset of images with two circles (green and blue) and save as numpy tensors.
@@ -132,57 +123,51 @@ def generate_circle_dataset(num_samples, image_size, radius, distance, delta):
             repeat(distance),
             repeat(radius),
             repeat(delta),
+            repeat(4),  # scale
+            repeat(enable_constraint),
+            repeat(enable_rectangle),
+            repeat(rectangle_thickness),
             chunksize=100,
         ):
             yield i, sample
 
 
-def visualize_samples(dataset: np.array, output_dir: str):
-    # Define the grid size (e.g., 5x5)
-    grid_size = 5
-    fig, axes = plt.subplots(grid_size, grid_size, figsize=(10, 10))
-
-    for i in range(grid_size):
-        for j in range(grid_size):
-            idx = i * grid_size + j
-            if idx < len(dataset):
-                img = dataset[idx]
-                axes[i, j].imshow(img)
-                centers = detect_circle_centers(img)
-                # Plot each detected center in red. Note that center_of_mass returns (row, col)
-                for center in centers:
-                    axes[i, j].scatter(center[1], center[0], c="red", s=20)
-            axes[i, j].axis("off")
-
-    plt.tight_layout()
-    plt.savefig(os.path.join(output_dir, "sample_grid.png"))
-    plt.close()
-
-
-if __name__ == "__main__":
+def main(args, experiment):
     # Create output directory if it doesn't exist
-    total_samples = 1_000_000
-    image_size = 32
-    distance = 11
-    delta = 4
-    radius = 3
+    image_size = experiment["image_size"]
 
-    output_dir = "data/circle_dataset"
-    os.makedirs(output_dir, exist_ok=True)
+    os.makedirs(args.output_dir, exist_ok=True)
 
-    dataset = np.empty((total_samples, image_size, image_size, 3), dtype=np.uint8)
+    dataset = np.empty((args.total_samples, image_size, image_size, 3), dtype=np.uint8)
 
     iterator = generate_circle_dataset(
         image_size=image_size,
-        num_samples=total_samples,
-        distance=distance,
-        delta=delta,
-        radius=radius,
+        num_samples=args.total_samples,
+        distance=experiment["distance"],
+        delta=experiment["delta"],
+        radius=experiment["radius"],
+        enable_constraint=True,
+        enable_rectangle=experiment["rectangle_thickness"] > 0,
+        rectangle_thickness=experiment["rectangle_thickness"],
    )
 
-    for i, sample in tqdm(iterator, total=total_samples):
+    for i, sample in tqdm(iterator, total=args.total_samples):
         dataset[i] = sample
 
-    visualize_samples(dataset, output_dir)
+    visualize_samples(
+        dataset, args.output_dir, f"sample_grid_{args.experiment_name}.png"
+    )
 
     # Save the dataset
-    np.save(os.path.join(output_dir, "data32.npy"), dataset)
+    np.save(os.path.join(args.output_dir, f"data-{args.experiment_name}.npy"), dataset)
     # np.savez_compressed(os.path.join(output_dir, "data.npy.npz"), dataset)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Generate a dataset of circles")
+    parser.add_argument(
+        "-e", "--experiment_name", required=True, type=str, choices=EXPERIMENTS.keys()
+    )
+    parser.add_argument("-o", "--output_dir", default="data/circle_dataset")
+    parser.add_argument("-t", "--total_samples", default=1_000_000, type=int)
+
+    args = parser.parse_args()
+    main(args, EXPERIMENTS[args.experiment_name])
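The oriented rectangle in `create_sample_antialiased` is built by extending the center-to-center segment by one radius on each side and offsetting both endpoints along the perpendicular unit vector. A standalone sketch of the same geometry (the numeric inputs are made up for illustration):

```python
import numpy as np

# Circle centers and radius, in high-res pixels (invented example values).
x0, y0, x1, y1 = 8.0, 8.0, 24.0, 16.0
r = 12.0  # high_res_radius

dx, dy = x1 - x0, y1 - y0
d = np.hypot(dx, dy)
ux, uy = (dx / d, dy / d) if d else (0.0, 0.0)  # unit vector circle0 -> circle1

# Extend the segment by one radius on each side so both circles are enclosed.
start = (x0 - r * ux, y0 - r * uy)
end = (x1 + r * ux, y1 + r * uy)

# Half-thickness: at least one circle diameter, as in the patch
# (rectangle_thickness=4, scale=4).
half = max(4.0 * 4, 2 * r) / 2.0

# Offsetting both endpoints along the perpendicular gives the four corners.
px, py = -uy, ux
corners = [
    (start[0] + half * px, start[1] + half * py),
    (start[0] - half * px, start[1] - half * py),
    (end[0] - half * px, end[1] - half * py),
    (end[0] + half * px, end[1] + half * py),
]
print(corners)  # pass to ImageDraw.Draw.polygon(corners, fill=WHITE)
```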
diff --git a/src/ddpm/sample.py b/src/ddpm/sample.py
index ed2efea..4ed1f5f 100644
--- a/src/ddpm/sample.py
+++ b/src/ddpm/sample.py
@@ -1,288 +1,23 @@
-from typing import Dict
-from concurrent.futures import ProcessPoolExecutor
 import torch
-from tqdm import tqdm
 import matplotlib
 import matplotlib.pyplot as plt
-import seaborn as sns
 import numpy as np
-from train import extract, get_diffusion_params
-from train import TIMESTEPS, IMAGE_SIZE, DDIM_TIMESTEPS
+from diffusion import get_diffusion_params, TIMESTEPS
 from model import UNet
-from generate_circle_dataset import detect_circle_centers
+from utils import compute_statistics, generate_synthetic_samples, visualize_samples
+from values import EXPERIMENTS
 
 matplotlib.use("Agg")
 
 torch.manual_seed(1)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-GENERATE_SYNTHETIC_DATA = False
-
-
-@torch.compile
-@torch.no_grad()
-def ddpm_sample(
-    model: torch.nn.Module,
-    x: torch.Tensor,
-    t: torch.Tensor,
-    params: Dict[str, torch.Tensor],
-) -> torch.Tensor:
-    """Sample from the model at timestep t"""
-    predicted_noise = model(x, t)
-
-    one_over_alphas = extract(params["one_over_alphas"], t, x.shape)
-    posterior_mean_coef = extract(params["posterior_mean_coef"], t, x.shape)
-
-    pred_mean = one_over_alphas * (x - posterior_mean_coef * predicted_noise)
-
-    posterior_variance = extract(params["posterior_variance"], t, x.shape)
-
-    if t[0] > 0:
-        noise = torch.randn_like(x)
-        return pred_mean + torch.sqrt(posterior_variance) * noise
-    else:
-        return pred_mean
-
-
-@torch.no_grad()
-def ddim_sample(
-    model: torch.nn.Module,
-    x: torch.Tensor,
-    t: torch.Tensor,
-    params: Dict[str, torch.Tensor],
-) -> torch.Tensor:
-    """Sample from the model in a non-markovian way (DDIM)"""
-    stride = TIMESTEPS // DDIM_TIMESTEPS
-    t_prev = t - stride
-    predicted_noise = model(x, t)
-
-    alphas_prod = extract(params["alphas_cumprod"], t, x.shape)
-    valid_mask = (t_prev >= 0).view(-1, 1, 1, 1)
-    safe_t_prev = torch.maximum(t_prev, torch.tensor(0, device=device))
-    alphas_prod_prev = extract(params["alphas_cumprod"], safe_t_prev, x.shape)
-    alphas_prod_prev = torch.where(
-        valid_mask, alphas_prod_prev, torch.ones_like(alphas_prod_prev)
-    )
-
-    sigma = extract(params["ddim_sigma"], t, x.shape)
-
-    pred_x0 = (x - (1 - alphas_prod).sqrt() * predicted_noise) / alphas_prod.sqrt()
-
-    pred = (
-        alphas_prod_prev.sqrt() * pred_x0
-        + (1.0 - alphas_prod_prev).sqrt() * predicted_noise
-    )
-
-    if t[0] > 0:
-        noise = torch.randn_like(x)
-        pred = pred + noise * sigma
-
-    return pred
-
-
-@torch.no_grad()
-def ddpm_sample_images(
-    model: torch.nn.Module,
-    image_size: int,
-    batch_size: int,
-    channels: int,
-    device: torch.device,
-    params: Dict[str, torch.Tensor],
-):
-    """Generate new images using the trained model"""
-    x = torch.randn(batch_size, channels, image_size, image_size).to(device)
-
-    for t in tqdm(
-        reversed(range(TIMESTEPS)), desc="DDPM Sampling", total=TIMESTEPS, leave=False
-    ):
-        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
-        x = ddpm_sample(model, x, t_batch, params)
-
-        if x.isnan().any():
-            raise ValueError(f"NaN detected in image at timestep {t}")
-
-    return x
-
-
-def get_ddim_timesteps(
-    total_timesteps: int, num_sampling_timesteps: int
-) -> torch.Tensor:
-    """Gets the timesteps used for the DDIM process."""
-    assert total_timesteps % num_sampling_timesteps == 0
-    stride = total_timesteps // num_sampling_timesteps
-    timesteps = torch.arange(0, total_timesteps, stride)
-    return timesteps.flip(0)
-
-
-@torch.no_grad()
-def ddim_sample_images(
-    model: torch.nn.Module,
-    image_size: int,
-    batch_size: int,
-    channels: int,
-    device: torch.device,
-    params: Dict[str, torch.Tensor],
-):
-    """Generate new images using the trained model"""
-    x = torch.randn(batch_size, channels, image_size, image_size).to(device)
-
-    timesteps = get_ddim_timesteps(TIMESTEPS, DDIM_TIMESTEPS)
-
-    for i in tqdm(range(len(timesteps) - 1), desc="DDIM Sampling"):
-        t = torch.full((batch_size,), timesteps[i], device=device, dtype=torch.long)
-        x = ddim_sample(model, x, t, params)
-
-        if x.isnan().any():
-            raise ValueError(f"NaN detected at timestep {timesteps[i]}")
-
-    return x
-
-
-def show_images(images: torch.Tensor, title=""):
-    """Display a batch of images in a grid"""
-    for idx in range(min(16, len(images))):
-        plt.subplot(4, 4, idx + 1)
-        plt.imshow(images[idx])
-        # plt.imshow(np.transpose(images[idx], (1, 2, 0)))
-        plt.axis("off")
-    plt.suptitle(title)
-    plt.savefig("media/circles-predicted.png")
-    plt.close()
-
-
-def compute_statistics(data: np.array, generated: np.array):
-    data_centers = []
-    num_bad_samples = 0
-    # for centers in map(detect_circle_centers, data):
-    #     if len(centers) == 2:
-    #         data_centers.append(np.array(centers))
-    #     else:
-    #         num_bad_samples += 1
-    with ProcessPoolExecutor(max_workers=8) as executor:
-        for centers in executor.map(detect_circle_centers, data, chunksize=8):
-            if len(centers) == 2:
-                data_centers.append(np.array(centers))
-            else:
-                num_bad_samples += 1
-
-    if num_bad_samples > 0:
-        print("num bad samples in data: ", num_bad_samples)
-
-    data_centers = np.stack(data_centers, axis=0)  # (num_samples, 2, 2)
-
-    num_bad_samples = 0
-    generated_centers = []
-    with ProcessPoolExecutor(max_workers=16) as executor:
-        for centers in executor.map(detect_circle_centers, generated, chunksize=8):
-            if len(centers) == 2:
-                generated_centers.append(np.array(centers))
-            else:
-                num_bad_samples += 1
-
-    if num_bad_samples > 0:
-        print("num bad samples in generated: ", num_bad_samples)
-
-    generated_centers = np.stack(generated_centers, axis=0)  # (num_samples, 2, 2)
-
-    # Calculate distances from the center of the image
-    # image_center = IMAGE_SIZE / 2
-    # data_distances = np.sqrt(
-    #     (data_centers[:, 0] - image_center) ** 2
-    #     + (data_centers[:, 1] - image_center) ** 2
-    # )
-    # generated_distances = np.sqrt(
-    #     (generated_centers[:, 0] - image_center) ** 2
-    #     + (generated_centers[:, 1] - image_center) ** 2
-    # )
-
-    # Create a figure with subplots
-
-    plt.figure(figsize=(15, 10))
-
-    # Plot histogram of x positions
-    plt.subplot(2, 2, 1)
-    sns.histplot(
-        data_centers[:, :, 0].reshape(-1),
-        color="blue",
-        label="Data",
-        kde=True,
-        stat="density",
-    )
-    sns.histplot(
-        generated_centers[:, :, 0].reshape(-1),
-        color="orange",
-        label="Generated",
-        kde=True,
-        stat="density",
-    )
-    plt.title("X Position Distribution")
-    plt.xlabel("X Position")
-    plt.legend()
-
-    # Plot histogram of y positions
-    plt.subplot(2, 2, 2)
-    sns.histplot(
-        data_centers[:, :, 1].reshape(-1),
-        color="blue",
-        label="Data",
-        kde=True,
-        stat="density",
-    )
-    sns.histplot(
-        generated_centers[:, :, 1].reshape(-1),
-        color="orange",
-        label="Generated",
-        kde=True,
-        stat="density",
-    )
-    plt.title("Y Position Distribution")
-    plt.xlabel("Y Position")
-    plt.legend()
-
-    # Plot histogram of distances
-    plt.subplot(2, 2, 3)
-    distances = np.sqrt(
-        np.square(data_centers[:, ::2, 0] - data_centers[:, 1::2, 0])
-        + np.square(data_centers[:, ::2, 1] - data_centers[:, 1::2, 1])
-    ).squeeze()
-    generated_distances = np.sqrt(
-        np.square(generated_centers[:, ::2, 0] - generated_centers[:, 1::2, 0])
-        + np.square(generated_centers[:, ::2, 1] - generated_centers[:, 1::2, 1])
-    ).squeeze()
-    sns.histplot(distances, color="blue", label="Data", kde=True, stat="density")
-    sns.histplot(
-        generated_distances, color="orange", label="Generated", kde=True, stat="density"
-    )
-    plt.title("Distance between circles distribution")
-    plt.xlabel("Distance")
-    plt.legend()
-
-    # Plot 2D heatmap of center positions
-    plt.subplot(2, 2, 4)
-    sns.kdeplot(
-        x=data_centers[:, :, 0].reshape(-1),
-        y=data_centers[:, :, 1].reshape(-1),
-        cmap="Blues",
-        label="Data",
-    )
-    sns.kdeplot(
-        x=generated_centers[:, :, 0].reshape(-1),
-        y=generated_centers[:, :, 1].reshape(-1),
-        cmap="Oranges",
-        label="Generated",
-    )
-    plt.title("2D Heatmap of Center Positions")
-    plt.xlabel("X Position")
-    plt.ylabel("Y Position")
-    plt.xlim(0, IMAGE_SIZE)
-    plt.ylim(0, IMAGE_SIZE)
-    plt.legend()
-
-    plt.tight_layout()
-    plt.savefig("media/center_statistics.png")
-    print("Saved histograms at media/center_statistics.png")
-    plt.close()
+# train.py no longer defines IMAGE_SIZE; this script is hardwired to the
+# "bigrectangle" experiment, so take the image size from its config.
+EXPERIMENT_NAME = "bigrectangle"
+IMAGE_SIZE = EXPERIMENTS[EXPERIMENT_NAME]["image_size"]
+
+GENERATE_SYNTHETIC_DATA = True
 
 
 def load_dataset(data_path: str):
@@ -292,69 +27,34 @@ def load_dataset(data_path: str):
     return data
 
 
-def plot_bad_centers(generated: np.array):
-    generated_centers = []
-    num_bad_samples = 0
-    with ProcessPoolExecutor(max_workers=16) as executor:
-        for centers in executor.map(detect_circle_centers, generated, chunksize=8):
-            if len(centers) == 2:
-                generated_centers.append(np.array(centers))
-            else:
-                num_bad_samples += 1
-                generated_centers.append(np.zeros((2, 2)))
-
-    if num_bad_samples > 0:
-        print("num bad samples in generated: ", num_bad_samples)
-
-    generated_centers = np.stack(generated_centers, axis=0)  # (num_samples, 2, 2)
-    generated_distances = np.sqrt(
-        np.square(generated_centers[:, ::2, 0] - generated_centers[:, 1::2, 0])
-        + np.square(generated_centers[:, ::2, 1] - generated_centers[:, 1::2, 1])
-    ).squeeze()
-
-    mask = generated_distances > 18.0
-    generated = generated[mask]
-    show_images(generated)
-
-
 if __name__ == "__main__":
     torch.set_float32_matmul_precision("high")
     plt.figure(figsize=(10, 10))
-    data = load_dataset("./data/circle_dataset/data32.npy")
+    data = load_dataset(f"./data/circle_dataset/data-{EXPERIMENT_NAME}.npy")
 
     if GENERATE_SYNTHETIC_DATA:
-        nb_synthetic_samples = 10_000
+        nb_synthetic_samples = 102_400
         params = get_diffusion_params(TIMESTEPS, device, eta=0.0)
-        model = UNet(32, TIMESTEPS).to(device)
-        model.load_state_dict(torch.load("model.pkl", weights_only=True))
+        model = UNet(IMAGE_SIZE, TIMESTEPS).to(device)
+        model.load_state_dict(
+            torch.load(f"model-{EXPERIMENT_NAME}.pkl", weights_only=True)
+        )
         model.eval()
-        generated = np.empty_like(data)
-        chunk = 500
-        samples = min(nb_synthetic_samples, data.shape[0])
-        for i in tqdm(range(samples // chunk), desc="Generating synthetic data."):
-            generated_images = ddpm_sample_images(
-                model=model,
-                image_size=IMAGE_SIZE,
-                batch_size=chunk,
-                channels=3,
-                device=device,
-                params=params,
-            )
-            generated_images = torch.permute(generated_images, (0, 2, 3, 1))
-            generated[i * chunk : (i + 1) * chunk] = (
-                (generated_images * 255.0)
-                .clip(min=0.0, max=255.0)
-                .cpu()
-                .numpy()
-                .astype(np.uint8)
-            )
-        np.save("./data/circle_dataset/generated32.npy", generated)
+        generated = generate_synthetic_samples(
+            model,
+            image_size=IMAGE_SIZE,
+            nb_synthetic_samples=nb_synthetic_samples,
+            params=params,
+            device=device,
+        )
+        np.save(f"./data/circle_dataset/generated-{EXPERIMENT_NAME}.npy", generated)
     else:
-        generated = np.load("./data/circle_dataset/generated32.npy")
+        generated = np.load(f"./data/circle_dataset/generated-{EXPERIMENT_NAME}.npy")
 
+    visualize_samples(generated, "")
     compute_statistics(data, generated)
-    plot_bad_centers(generated)
+    # plot_k_centers(generated, 1)
+    # plot_bad_centers(generated)
+    # plot_more_than_2_centers(generated)
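One detail worth seeing concretely: `get_ddim_timesteps` returns the strided schedule high-to-low, and `ddim_sample_images` iterates over all but the final entry, so each `ddim_sample` call jumps from t to t - stride. A quick check at the repo's defaults (TIMESTEPS = 1000, DDIM_TIMESTEPS = 500; this assumes src/ddpm is on the import path):

```python
from diffusion import get_ddim_timesteps

ts = get_ddim_timesteps(1000, 500)        # stride = 2
print(ts[:3].tolist(), ts[-3:].tolist())  # [998, 996, 994] [4, 2, 0]
# ddim_sample_images loops over ts[:-1], so the last model call happens at
# t = 2, which lands on t_prev = 0 inside ddim_sample.
```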
diff --git a/src/ddpm/train.py b/src/ddpm/train.py
index 40872d5..a48dcff 100644
--- a/src/ddpm/train.py
+++ b/src/ddpm/train.py
@@ -1,18 +1,30 @@
 from typing import Dict, Callable
-from itertools import islice
+import subprocess
+import argparse
 
 import numpy as np
 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
+import matplotlib
+import matplotlib.pyplot as plt
 
 from model import UNet
+from diffusion import get_diffusion_params, extract, TIMESTEPS
+from utils import generate_synthetic_samples, compute_statistics
+from values import EXPERIMENTS
+
+matplotlib.use("Agg")
+
 
 # Hyperparameters
-NUM_EPOCHS = 1
+NUM_EPOCHS = 10
 BATCH_SIZE = 512
-IMAGE_SIZE = 32
 CHANNELS = 3
-TIMESTEPS = 1000
-DDIM_TIMESTEPS = 500
+
+# Histogram generation
+HIST_GENERATED_SAMPLES = 10_000
+# Using the whole dataset becomes very slow, so we take a subset.
+HIST_DATASET_SAMPLES = 100_000
 
 
 def load_dataset(data_path: str):
@@ -24,12 +36,13 @@ def load_dataset(data_path: str):
     return data
 
 
-def create_dataset_loader(data: np.array):
-    nb_batches = data.shape[0] // BATCH_SIZE
+def create_dataset_loader(data: np.array, batch_size: int):
+    nb_batches = data.shape[0] // batch_size
     ids = np.arange(data.shape[0])
+    np.random.shuffle(ids)
 
     for i in range(nb_batches):
-        batch_ids = ids[i * BATCH_SIZE : (i + 1) * BATCH_SIZE]
+        batch_ids = ids[i * batch_size : (i + 1) * batch_size]
         yield data[batch_ids]
 
 
@@ -37,75 +50,9 @@ def count_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
 
 
-def get_diffusion_params(
-    timesteps: int,
-    device: torch.device,
-    ddim_timesteps: int = DDIM_TIMESTEPS,
-    eta=0.0,
-) -> Dict[str, torch.Tensor]:
-    def linear_beta_schedule(timesteps):
-        beta_start = 0.0001
-        beta_end = 0.02
-        return torch.linspace(beta_start, beta_end, timesteps)
-
-    betas = linear_beta_schedule(timesteps)
-    alphas = 1.0 - betas
-    alphas_cumprod = torch.cumprod(alphas, dim=0)
-    alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
-
-    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
-
-    one_over_alphas = 1.0 / torch.sqrt(alphas)
-    posterior_mean_coef = betas / sqrt_one_minus_alphas_cumprod
-
-    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
-
-    ddim_sigma = eta * torch.sqrt(
-        (1.0 - alphas_cumprod_prev)
-        / (1.0 - alphas_cumprod)
-        * (1 - alphas_cumprod / alphas_cumprod_prev)
-    )
-
-    return {
-        # DDPM Parameters
-        "betas": betas.to(device),
-        "alphas_cumprod": alphas_cumprod.to(device),
-        "posterior_variance": posterior_variance.to(device),
-        "one_over_alphas": one_over_alphas.to(device),
-        "posterior_mean_coef": posterior_mean_coef.to(device),
-        # DDIM Parameters
-        "ddim_sigma": ddim_sigma.to(device),
-    }
-
-
-def extract(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Tensor.shape):
-    """Extract coefficients at specified timesteps t"""
-    batch_size = t.shape[0]
-    out = a.gather(-1, t)
-    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
-
-
-def get_lr(
-    it: int,
-    warmup_iters: int = 80,
-    lr_decay_iters: int = 900,
-    min_lr: float = 3e-5,
-    learning_rate: float = 1e-4,
-):
-    # 1) linear warmup for warmup_iters steps
-    if it < warmup_iters:
-        return learning_rate * (it + 1) / (warmup_iters + 1)
-    # 2) if it > lr_decay_iters, return min learning rate
-    if it > lr_decay_iters:
-        return min_lr
-    # 3) in between, use cosine decay down to min learning rate
-    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
-    assert 0 <= decay_ratio <= 1
-    coeff = 0.5 * (1.0 + np.cos(np.pi * decay_ratio))  # coeff ranges 0..1
-    return min_lr + coeff * (learning_rate - min_lr)
-
-
 def get_loss_fn(model: torch.nn.Module, params: Dict[str, torch.Tensor]) -> Callable:
+    device = next(model.parameters()).device
+
     @torch.compile
     def loss_fn(x_0):
         batch_size = x_0.shape[0]
@@ -126,39 +73,37 @@ def get_loss_fn(model: torch.nn.Module, params: Dict[str, torch.Tensor]) -> Callable:
 
 
 def train_epoch(
-    model: torch.nn.Module, optimize, train_loader: DataLoader, loss_fn: Callable
-) -> float:
-    model.train()
-    total_loss = 0
+    model: torch.nn.Module, optimizer, train_loader: DataLoader, loss_fn: Callable
+) -> list:
+    device = next(model.parameters()).device
 
-    with tqdm(islice(train_loader, 200), leave=False) as pbar:
-        steps = 0
+    model.train()
+    train_losses = []
+    with tqdm(train_loader, leave=False) as pbar:
         for batch in pbar:
-            # lr = get_lr(steps)
-            # for param_group in optimizer.param_groups:
-            #     param_group["lr"] = lr
             images = torch.tensor(batch, device=device)
 
             optimizer.zero_grad()
             loss = loss_fn(images)
             loss.backward()
             optimizer.step()
+            train_losses.append(loss.item())
 
-            total_loss += loss.item()
-            steps += 1
             pbar.set_description(f"Loss: {loss.item():.4f}")
 
-    return total_loss / steps
+    return train_losses
 
 
-if __name__ == "__main__":
+def main(args):
     torch.set_float32_matmul_precision("high")
 
+    image_size = EXPERIMENTS[args.experiment_name]["image_size"]
+
     # Set random seed for reproducibility
     torch.manual_seed(42)
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    model = UNet(32, TIMESTEPS).to(device)
+    model = UNet(image_size, TIMESTEPS).to(device)
     nb_params = count_parameters(model)
     print(f"Total number of parameters: {nb_params}")
@@ -166,12 +111,79 @@ if __name__ == "__main__":
     params = get_diffusion_params(TIMESTEPS, device)
     loss_fn = get_loss_fn(model, params)
 
-    data = load_dataset("./data/circle_dataset/data32.npy")
+    data = load_dataset(f"./data/circle_dataset/data-{args.experiment_name}.npy")
+
+    if args.histogram:
+        # Sample a fixed subset of the dataset for the per-epoch histograms.
+        ids = np.random.choice(
+            data.shape[0], size=min(HIST_DATASET_SAMPLES, data.shape[0]), replace=False
+        )
+        data_hist_set = data[ids]
 
     # Main training loop
-    for e in tqdm(range(NUM_EPOCHS)):
-        train_loader = create_dataset_loader(data)
-        train_epoch(model, optimizer, train_loader, loss_fn)
+    train_losses = []
+    for e in tqdm(range(args.epochs)):
+        train_loader = create_dataset_loader(data, BATCH_SIZE)
+        epoch_losses = train_epoch(model, optimizer, train_loader, loss_fn)
+        train_losses.extend(epoch_losses)
+
+        if args.histogram:
+            generated = generate_synthetic_samples(
+                model,
+                image_size=image_size,
+                nb_synthetic_samples=HIST_GENERATED_SAMPLES,
+                params=params,
+                device=device,
+            )
+            compute_statistics(
+                data_hist_set,
+                generated,
+                output_path=f"center_statistics_{args.experiment_name}_{e}.png",
+            )
 
     # Save model after training
-    torch.save(model.state_dict(), "model.pkl")
+    torch.save(model.state_dict(), f"model-{args.experiment_name}.pkl")
+
+    # Plot training loss curve
+    plt.figure(figsize=(10, 5))
+    plt.plot(train_losses, label="Training Loss")
+    plt.yscale("log")
+    plt.xlabel("Batch")
+    plt.ylabel("Loss")
+    plt.title("Training Loss Curve")
+    plt.legend()
+    plt.grid(True)
+    plt.savefig(f"training_loss_{args.experiment_name}.png")
+    plt.close()
+
+    # Generate an animation of the statistics over training epochs
+    if args.histogram:
+        # Call ffmpeg to create a video from the per-epoch images
+        # (compute_statistics saves them under media/).
+        subprocess.run(
+            [
+                "ffmpeg",
+                "-framerate",
+                "1",
+                "-i",
+                f"media/center_statistics_{args.experiment_name}_%d.png",
+                "-c:v",
+                "libx264",
+                "-r",
+                "30",
+                "-pix_fmt",
+                "yuv420p",
+                f"center_statistics_{args.experiment_name}.mp4",
+            ]
+        )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Train a DDPM model.")
+    parser.add_argument("-H", "--histogram", action="store_true")
+    parser.add_argument(
+        "-e",
+        "--experiment_name",
+        default="vanilla",
+        type=str,
+        choices=EXPERIMENTS.keys(),
+    )
+    parser.add_argument("--epochs", default=NUM_EPOCHS, type=int)
+    args = parser.parse_args()
+    main(args)
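The body of `loss_fn` falls outside the hunk above, so the actual objective is not shown in this patch. For orientation only, a standard epsilon-prediction DDPM loss written against the tensors `get_diffusion_params` actually returns would look roughly like the following sketch; `make_loss_fn` and its internals are hypothetical names, not the repo's code:

```python
import torch
import torch.nn.functional as F


def make_loss_fn(model, params, timesteps=1000):
    """Hypothetical sketch of an epsilon-prediction DDPM training loss."""
    device = next(model.parameters()).device
    sqrt_ab = params["alphas_cumprod"].sqrt()                 # sqrt(abar_t)
    sqrt_one_minus_ab = (1.0 - params["alphas_cumprod"]).sqrt()

    def loss_fn(x_0):
        # x_0: float tensor (B, C, H, W) in the model's input range.
        b = x_0.shape[0]
        t = torch.randint(0, timesteps, (b,), device=device)
        noise = torch.randn_like(x_0)
        # Forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise
        x_t = (
            sqrt_ab[t].view(-1, 1, 1, 1) * x_0
            + sqrt_one_minus_ab[t].view(-1, 1, 1, 1) * noise
        )
        # Train the network to recover the injected noise.
        return F.mse_loss(model(x_t, t), noise)

    return loss_fn
```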
diff --git a/src/ddpm/utils.py b/src/ddpm/utils.py
new file mode 100644
index 0000000..85a4f4a
--- /dev/null
+++ b/src/ddpm/utils.py
@@ -0,0 +1,329 @@
+import os
+import torch
+from concurrent.futures import ProcessPoolExecutor
+from itertools import repeat
+
+from scipy.ndimage import label, center_of_mass
+import numpy as np
+from tqdm import tqdm
+from values import GREEN, BLUE
+
+import matplotlib
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+from diffusion import ddpm_sample_images
+
+matplotlib.use("Agg")
+
+
+def detect_circle_centers(image, circle_colors, threshold=30, min_pixels=12):
+    """
+    Detects centers of circles in an image based on their known colors, filtering out
+    regions that have fewer than a specified number of pixels.
+
+    This function creates a mask for each provided circle color by selecting pixels
+    whose RGB values are close to the target color (within a per-channel threshold).
+    It then labels the connected regions in the mask, filters out small regions based
+    on the min_pixels parameter, and computes the centers of the remaining regions.
+
+    Args:
+        image (np.ndarray): The image array with shape (H, W, 3).
+        circle_colors (list of tuple): List of RGB tuples for the circle colors to
+            detect, e.g. [GREEN, BLUE].
+        threshold (int): Maximum allowed difference per channel between a pixel and
+            the target circle color.
+        min_pixels (int): Minimum number of pixels for a region to be considered valid.
+
+    Returns:
+        centers (list of tuples): List of (row, col) coordinates for each detected circle.
+        image (np.ndarray): The original image.
+    """
+    centers = []
+    # Loop over each target circle color.
+    for color in circle_colors:
+        # Compute absolute difference between each pixel and the target color.
+        diff = np.abs(image.astype(np.int16) - np.array(color, dtype=np.int16))
+        # Create a mask: pixels where all channels are within the threshold.
+        mask = np.all(diff < threshold, axis=-1)
+
+        # Label connected regions in the mask.
+        labeled, num_features = label(mask)
+
+        # Process each labeled region.
+        for i in range(1, num_features + 1):
+            # Count the number of pixels in the current region.
+            region_size = np.sum(labeled == i)
+            # Skip regions that are smaller than the minimum required.
+            if region_size < min_pixels:
+                continue
+
+            center = center_of_mass(mask, labeled, i)
+            centers.append(center)
+
+    return centers, image
+
+
+def compute_statistics(
+    data: np.array,
+    generated: np.array,
+    output_path: str = "center_statistics.png",
+):
+    # Data is channels-last: (num_samples, H, W, 3) with H == W.
+    assert len(data.shape) == 4
+    assert data.shape[1] == data.shape[2]
+    image_size = data.shape[1]
+
+    data_centers = []
+    num_bad_samples = 0
+    with ProcessPoolExecutor(max_workers=8) as executor:
+        for centers, _ in executor.map(
+            detect_circle_centers, data, repeat((BLUE, GREEN)), chunksize=8
+        ):
+            if len(centers) == 2:
+                data_centers.append(np.array(centers))
+            else:
+                num_bad_samples += 1
+
+    if num_bad_samples > 0:
+        print("num bad samples in data: ", num_bad_samples)
+
+    data_centers = np.stack(data_centers, axis=0)  # (num_samples, 2, 2)
+
+    num_bad_samples = 0
+    generated_centers = []
+    with ProcessPoolExecutor(max_workers=16) as executor:
+        for centers, _ in executor.map(
+            detect_circle_centers, generated, repeat((BLUE, GREEN)), chunksize=8
+        ):
+            if len(centers) == 2:
+                generated_centers.append(np.array(centers))
+            else:
+                num_bad_samples += 1
+
+    if num_bad_samples > 0:
+        print("num bad samples in generated: ", num_bad_samples)
+
+    generated_centers = np.stack(generated_centers, axis=0)  # (num_samples, 2, 2)
+
+    # Create a figure with subplots
+    plt.figure(figsize=(15, 10))
+
+    # Plot histogram of x positions
+    plt.subplot(2, 2, 1)
+    sns.histplot(
+        data_centers[:, :, 0].reshape(-1),
+        color="blue",
+        label="Data",
+        kde=True,
+        stat="density",
+    )
+    sns.histplot(
+        generated_centers[:, :, 0].reshape(-1),
+        color="orange",
+        label="Generated",
+        kde=True,
+        stat="density",
+    )
+    plt.title("X Position Distribution")
+    plt.xlabel("X Position")
+    plt.legend()
+
+    # Plot histogram of y positions
+    plt.subplot(2, 2, 2)
+    sns.histplot(
+        data_centers[:, :, 1].reshape(-1),
+        color="blue",
+        label="Data",
+        kde=True,
+        stat="density",
+    )
+    sns.histplot(
+        generated_centers[:, :, 1].reshape(-1),
+        color="orange",
+        label="Generated",
+        kde=True,
+        stat="density",
+    )
+    plt.title("Y Position Distribution")
+    plt.xlabel("Y Position")
+    plt.legend()
+
+    # Plot histogram of distances
+    plt.subplot(2, 2, 3)
+    distances = np.sqrt(
+        np.square(data_centers[:, ::2, 0] - data_centers[:, 1::2, 0])
+        + np.square(data_centers[:, ::2, 1] - data_centers[:, 1::2, 1])
+    ).squeeze()
+    generated_distances = np.sqrt(
+        np.square(generated_centers[:, ::2, 0] - generated_centers[:, 1::2, 0])
+        + np.square(generated_centers[:, ::2, 1] - generated_centers[:, 1::2, 1])
+    ).squeeze()
+    sns.histplot(distances, color="blue", label="Data", kde=True, stat="density")
+    sns.histplot(
+        generated_distances, color="orange", label="Generated", kde=True, stat="density"
+    )
+    plt.title("Distance between circles distribution")
+    plt.xlabel("Distance")
+    plt.legend()
+
+    # Plot 2D heatmap of center positions
+    plt.subplot(2, 2, 4)
+    sns.kdeplot(
+        x=data_centers[:, :, 0].reshape(-1),
+        y=data_centers[:, :, 1].reshape(-1),
+        cmap="Blues",
+        label="Data",
+    )
+    sns.kdeplot(
+        x=generated_centers[:, :, 0].reshape(-1),
+        y=generated_centers[:, :, 1].reshape(-1),
+        cmap="Oranges",
+        label="Generated",
+    )
+    plt.title("2D Heatmap of Center Positions")
+    plt.xlabel("X Position")
+    plt.ylabel("Y Position")
+    plt.xlim(0, image_size)
+    plt.ylim(0, image_size)
+    plt.legend()
+
+    plt.tight_layout()
+    output_path = os.path.join("media", output_path)
+    plt.savefig(output_path)
+    print(f"Saved histograms at {output_path}")
+    plt.close()
+
+
+def plot_k_centers(dataset: np.array, k: int = 1):
+    """From a given dataset, plot some samples that have k centers in them."""
+    images = []
+    with ProcessPoolExecutor(max_workers=8) as executor:
+        for centers, image in executor.map(
+            detect_circle_centers, dataset, repeat((BLUE, GREEN)), chunksize=8
+        ):
+            if len(centers) == k:
+                images.append(image)
+
+    print("num samples: ", len(images))
+
+    images = np.stack(images, axis=0)
+    show_images(images, f"{k} circles", grid_size=8)
+
+
+def plot_bad_centers(dataset: np.array, threshold: float = 18.0):
+    """From a given dataset, plot some samples that violate the distance constraint."""
+    generated_centers = []
+    num_bad_samples = 0
+    with ProcessPoolExecutor(max_workers=8) as executor:
+        for centers, _ in executor.map(
+            detect_circle_centers, dataset, repeat((BLUE, GREEN)), chunksize=8
+        ):
+            if len(centers) == 2:
+                generated_centers.append(np.array(centers))
+            else:
+                generated_centers.append(np.zeros((2, 2)))
+                num_bad_samples += 1
+
+    if num_bad_samples > 0:
+        print("num bad samples in generated: ", num_bad_samples)
+
+    generated_centers = np.stack(generated_centers, axis=0)  # (num_samples, 2, 2)
+    generated_distances = np.sqrt(
+        np.square(generated_centers[:, ::2, 0] - generated_centers[:, 1::2, 0])
+        + np.square(generated_centers[:, ::2, 1] - generated_centers[:, 1::2, 1])
+    ).squeeze()
+
+    mask = generated_distances > threshold
+    generated = dataset[mask]
+    show_images(generated)
+
+
+def plot_more_than_2_centers(generated: np.array):
+    """From a given dataset, plot samples in which more than two centers were found."""
+    images = []
+    with ProcessPoolExecutor(max_workers=16) as executor:
+        for centers, image in executor.map(
+            detect_circle_centers, generated, repeat((BLUE, GREEN)), chunksize=8
+        ):
+            if len(centers) > 2:
+                images.append(image)
+
+    print("Nb images multiple centers: ", len(images))
+
+    if images:
+        show_images(np.stack(images, axis=0), "more than 2 circles")
+
+
+def show_images(images, title="", grid_size: int = 4):
+    """Display a batch of images in a grid"""
+    for idx in range(min(grid_size**2, len(images))):
+        ax = plt.subplot(grid_size, grid_size, idx + 1)
+        ax.imshow(images[idx])
+        centers, _ = detect_circle_centers(images[idx], (BLUE, GREEN))
+        for center in centers:
+            ax.scatter(center[1], center[0], c="red", s=20)
+        ax.axis("off")
+    plt.suptitle(title)
+    plt.savefig("media/circles-predicted.png")
+    plt.close()
+
+
+def visualize_samples(
+    dataset: np.array, output_dir: str, filename: str = "sample_grid.png"
+):
+    """From a given dataset, visualize some samples"""
+
+    # Define the grid size (e.g., 5x5)
+    grid_size = 5
+    fig, axes = plt.subplots(grid_size, grid_size, figsize=(10, 10))
+
+    for i in range(grid_size):
+        for j in range(grid_size):
+            idx = i * grid_size + j
+            if idx < len(dataset):
+                img = dataset[idx]
+                axes[i, j].imshow(img)
+                centers, _ = detect_circle_centers(img, (BLUE, GREEN))
+                # Plot each detected center in red. Note that center_of_mass
+                # returns (row, col).
+                for center in centers:
+                    axes[i, j].scatter(center[1], center[0], c="red", s=20)
+            axes[i, j].axis("off")
+
+    plt.tight_layout()
+    plt.savefig(os.path.join(output_dir, filename))
+    plt.close()
+
+
+def generate_synthetic_samples(
+    model,
+    image_size: int,
+    nb_synthetic_samples: int,
+    params,
+    device="cuda",
+):
+    model.eval()
+    generated = np.empty(
+        shape=(nb_synthetic_samples, image_size, image_size, 3), dtype=np.uint8
+    )
+    chunk = 128
+    # Step by `chunk` but allow a smaller final batch, so sample counts that are
+    # not a multiple of `chunk` don't leave uninitialized rows at the end.
+    for start in tqdm(
+        range(0, nb_synthetic_samples, chunk), desc="Generating synthetic data."
+    ):
+        batch_size = min(chunk, nb_synthetic_samples - start)
+        generated_images = ddpm_sample_images(
+            model=model,
+            image_size=image_size,
+            batch_size=batch_size,
+            channels=3,
+            device=device,
+            params=params,
+        )
+        generated_images = torch.permute(generated_images, (0, 2, 3, 1))
+        generated[start : start + batch_size] = (
+            (generated_images * 255.0)
+            .clip(min=0.0, max=255.0)
+            .cpu()
+            .numpy()
+            .astype(np.uint8)
+        )
+
+    return generated
diff --git a/src/ddpm/values.py b/src/ddpm/values.py
new file mode 100644
index 0000000..2a39023
--- /dev/null
+++ b/src/ddpm/values.py
@@ -0,0 +1,29 @@
+RED = (0xCC, 0x24, 0x1D)
+GREEN = (0x98, 0x97, 0x1A)
+BLUE = (0x45, 0x85, 0x88)
+WHITE = (0xFB, 0xF1, 0xC7)
+BACKGROUND = (0x50, 0x49, 0x45)
+
+EXPERIMENTS = {
+    "vanilla": {
+        "image_size": 32,
+        "distance": 11,
+        "delta": 4,
+        "radius": 3,
+        "rectangle_thickness": 0,
+    },
+    "rectangle": {
+        "image_size": 32,
+        "distance": 11,
+        "delta": 4,
+        "radius": 3,
+        "rectangle_thickness": 4,
+    },
+    "bigrectangle": {
+        "image_size": 32,
+        "distance": 11,
+        "delta": 4,
+        "radius": 3,
+        "rectangle_thickness": 8,
+    },
+}
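To make the color-keyed detection concrete, here is a small self-contained usage sketch of `detect_circle_centers` against a hand-drawn frame. The circle positions are invented for the example, and the expected output is approximate since `center_of_mass` returns float (row, col) pairs:

```python
import numpy as np
from PIL import Image, ImageDraw

from utils import detect_circle_centers
from values import GREEN, BLUE, BACKGROUND

# Draw one green and one blue disc on the dataset's background color.
im = Image.new("RGB", (32, 32), BACKGROUND)
d = ImageDraw.Draw(im)
d.ellipse((4, 4, 10, 10), fill=GREEN)    # radius-3 disc centered near (7, 7)
d.ellipse((18, 20, 24, 26), fill=BLUE)   # radius-3 disc centered near (21, 23)

# Colors are matched per channel within the threshold, so the background and
# the other circle are rejected; small specks are dropped via min_pixels.
centers, _ = detect_circle_centers(np.array(im), (BLUE, GREEN))
print(centers)  # approx. [(23.0, 21.0), (7.0, 7.0)] as (row, col) pairs
```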