diff --git a/src/ddpm/sample.py b/src/ddpm/sample.py index 4ed1f5f..46c498c 100644 --- a/src/ddpm/sample.py +++ b/src/ddpm/sample.py @@ -1,9 +1,11 @@ +import argparse import torch import matplotlib import matplotlib.pyplot as plt import numpy as np from train import get_diffusion_params -from train import TIMESTEPS, IMAGE_SIZE +from diffusion import TIMESTEPS +from values import EXPERIMENTS from model import UNet from generate_circle_dataset import ( @@ -12,12 +14,13 @@ from generate_circle_dataset import ( from utils import compute_statistics, generate_synthetic_samples +IMAGE_SIZE = 32 matplotlib.use("Agg") torch.manual_seed(1) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -GENERATE_SYNTHETIC_DATA = True +GENERATE_SYNTHETIC_DATA = False def load_dataset(data_path: str): @@ -27,19 +30,19 @@ def load_dataset(data_path: str): return data -if __name__ == "__main__": - torch.set_float32_matmul_precision("high") - +def main(args): plt.figure(figsize=(10, 10)) - data = load_dataset("./data/circle_dataset/data-bigrectangle.npy") + data = load_dataset(f"./data/circle_dataset/data-{args.experiment_name}.npy") if GENERATE_SYNTHETIC_DATA: nb_synthetic_samples = 102_400 params = get_diffusion_params(TIMESTEPS, device, eta=0.0) model = UNet(IMAGE_SIZE, TIMESTEPS).to(device) - model.load_state_dict(torch.load("model-bigrectangle.pkl", weights_only=True)) + model.load_state_dict( + torch.load(f"model-{args.experiment_name}.pkl", weights_only=True) + ) model.eval() generated = generate_synthetic_samples( @@ -49,12 +52,34 @@ if __name__ == "__main__": params=params, device=device, ) - np.save("./data/circle_dataset/generated-bigrectangle.npy", generated) + np.save( + f"./data/circle_dataset/generated-{args.experiment_name}.npy", generated + ) else: - generated = np.load("./data/circle_dataset/generated-bigrectangle.npy") + generated = np.load( + f"./data/circle_dataset/generated-{args.experiment_name}.npy" + ) visualize_samples(generated, "") 
compute_statistics(data, generated) # plot_k_centers(generated, 1) # plot_bad_centers(generated) # plot_more_than_2_centers(generated) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + parser = argparse.ArgumentParser( + description="Generate samples from a trained diffusion model and save to disk." + ) + parser.add_argument( + "-e", + "--experiment_name", + default="vanilla", + type=str, + choices=EXPERIMENTS.keys(), + ) + args = parser.parse_args() + + main(args) diff --git a/src/ddpm/train.py b/src/ddpm/train.py index a48dcff..5454dcf 100644 --- a/src/ddpm/train.py +++ b/src/ddpm/train.py @@ -21,19 +21,28 @@ BATCH_SIZE = 512 CHANNELS = 3 # Histogram generation -HIST_GENERATED_SAMPLES = 10_000 +HIST_GENERATED_SAMPLES = 10_112 HIST_DATASET_SAMPLES = ( 100_000 # Using the whole dataset becomes super slow, so we take a subset ) -def load_dataset(data_path: str): +def load_dataset(data_path: str, hist_set: int = 0): print("Loading dataset... ", end="", flush=True) - data = np.load(data_path).astype(np.float32) - data = data / 255.0 # normalize between [0-1] + data = np.load(data_path) + + data_hist = None + if hist_set > 0: + hist_ids = np.random.choice( + np.arange(data.shape[0]), replace=False, size=(hist_set,) + ) + data_hist = data[hist_ids] + + data = data.astype(np.float32) / 255.0 # normalize between [0-1] data = np.permute_dims(data, (0, 3, 1, 2)) # (b, h, w, c) -> (b, c, h, w) print("Done.") - return data + + return data, data_hist def create_dataset_loader(data: np.array, batch_size: int): @@ -111,11 +120,11 @@ def main(args): params = get_diffusion_params(TIMESTEPS, device) loss_fn = get_loss_fn(model, params) - data = load_dataset(f"./data/circle_dataset/data-{args.experiment_name}.npy") - - if args.histogram: - ids = np.random.choice(np.arange(data.shape[0]), replace=False) - data_hist_set = data[ids] + data_hist_samples = HIST_DATASET_SAMPLES if args.histogram else 0 + data, data_hist = load_dataset( + 
f"./data/circle_dataset/data-{args.experiment_name}.npy", + data_hist_samples, + ) # Main training loop train_losses = [] @@ -133,7 +142,7 @@ def main(args): device=device, ) compute_statistics( - data_hist_set, + data_hist, generated, output_path=f"center_statistics_{args.experiment_name}_{e}.png", ) diff --git a/src/ddpm/utils.py b/src/ddpm/utils.py index 85a4f4a..9ea0960 100644 --- a/src/ddpm/utils.py +++ b/src/ddpm/utils.py @@ -39,6 +39,11 @@ def detect_circle_centers(image, circle_colors, threshold=30, min_pixels=12): centers (list of tuples): List of (row, col) coordinates for each detected circle. image (np.ndarray): The original image. """ + + image_orig = image + if image.shape[0] == 3: + image = np.permute_dims(image, (1, 2, 0)) + centers = [] # Loop over each target circle color. for color in circle_colors: @@ -61,7 +66,7 @@ def detect_circle_centers(image, circle_colors, threshold=30, min_pixels=12): center = center_of_mass(mask, labeled, i) centers.append(center) - return centers, image + return centers, image_orig def compute_statistics( @@ -69,8 +74,6 @@ def compute_statistics( generated: np.array, output_path: str = "center_statistics.png", ): - assert len(data.shape) == 4 - assert data.shape[2] == data.shape[3] image_size = data.shape[2] data_centers = []