Browse Source

fix: refactor bug

master
CALVO GONZALEZ Ramon 10 months ago
parent
commit
340b4345de
  1. 43
      src/ddpm/sample.py
  2. 31
      src/ddpm/train.py
  3. 9
      src/ddpm/utils.py

43
src/ddpm/sample.py

@ -1,9 +1,11 @@
import argparse
import torch
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from train import get_diffusion_params
from train import TIMESTEPS, IMAGE_SIZE
from diffusion import TIMESTEPS
from values import EXPERIMENTS
from model import UNet
from generate_circle_dataset import (
@ -12,12 +14,13 @@ from generate_circle_dataset import (
from utils import compute_statistics, generate_synthetic_samples
IMAGE_SIZE = 32
matplotlib.use("Agg")
torch.manual_seed(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
GENERATE_SYNTHETIC_DATA = True
GENERATE_SYNTHETIC_DATA = False
def load_dataset(data_path: str):
@ -27,19 +30,19 @@ def load_dataset(data_path: str):
return data
if __name__ == "__main__":
torch.set_float32_matmul_precision("high")
def main(args):
plt.figure(figsize=(10, 10))
data = load_dataset("./data/circle_dataset/data-bigrectangle.npy")
data = load_dataset(f"./data/circle_dataset/data-{args.experiment_name}.npy")
if GENERATE_SYNTHETIC_DATA:
nb_synthetic_samples = 102_400
params = get_diffusion_params(TIMESTEPS, device, eta=0.0)
model = UNet(IMAGE_SIZE, TIMESTEPS).to(device)
model.load_state_dict(torch.load("model-bigrectangle.pkl", weights_only=True))
model.load_state_dict(
torch.load(f"model-{args.experiment_name}.pkl", weights_only=True)
)
model.eval()
generated = generate_synthetic_samples(
@ -49,12 +52,34 @@ if __name__ == "__main__":
params=params,
device=device,
)
np.save("./data/circle_dataset/generated-bigrectangle.npy", generated)
np.save(
f"./data/circle_dataset/generated-{args.experiment_name}.npy", generated
)
else:
generated = np.load("./data/circle_dataset/generated-bigrectangle.npy")
generated = np.load(
f"./data/circle_dataset/generated-{args.experiment_name}.npy"
)
visualize_samples(generated, "")
compute_statistics(data, generated)
# plot_k_centers(generated, 1)
# plot_bad_centers(generated)
# plot_more_than_2_centers(generated)
if __name__ == "__main__":
torch.set_float32_matmul_precision("high")
parser = argparse.ArgumentParser(
desc="Generate samples from a trained diffusion model and save to disk."
)
parser.add_argument(
"-e",
"--experiment_name",
default="vanilla",
type=str,
choices=EXPERIMENTS.keys(),
)
args = parser.parse_args()
main(args)

31
src/ddpm/train.py

@ -21,19 +21,28 @@ BATCH_SIZE = 512
CHANNELS = 3
# Histogram generation
HIST_GENERATED_SAMPLES = 10_000
HIST_GENERATED_SAMPLES = 10_112
HIST_DATASET_SAMPLES = (
100_000 # Using the whole dataset becomes super slow, so we take a subset
)
def load_dataset(data_path: str):
def load_dataset(data_path: str, hist_set: int = 0):
print("Loading dataset... ", end="", flush=True)
data = np.load(data_path).astype(np.float32)
data = data / 255.0 # normalize between [0-1]
data = np.load(data_path)
data_hist = None
if hist_set > 0:
hist_ids = np.random.choice(
np.arange(hist_set), replace=False, size=(hist_set,)
)
data_hist = data[hist_ids]
data = data.astype(np.float32) / 255.0 # normalize between [0-1]
data = np.permute_dims(data, (0, 3, 1, 2)) # (b, h, w, c) -> (b, c, h, w)
print("Done.")
return data
return data, data_hist
def create_dataset_loader(data: np.array, batch_size: int):
@ -111,11 +120,11 @@ def main(args):
params = get_diffusion_params(TIMESTEPS, device)
loss_fn = get_loss_fn(model, params)
data = load_dataset(f"./data/circle_dataset/data-{args.experiment_name}.npy")
if args.histogram:
ids = np.random.choice(np.arange(data.shape[0]), replace=False)
data_hist_set = data[ids]
data_hist_samples = HIST_DATASET_SAMPLES if args.histogram else None
data, data_hist = load_dataset(
f"./data/circle_dataset/data-{args.experiment_name}.npy",
data_hist_samples,
)
# Main training loop
train_losses = []
@ -133,7 +142,7 @@ def main(args):
device=device,
)
compute_statistics(
data_hist_set,
data_hist,
generated,
output_path=f"center_statistics_{args.experiment_name}_{e}.png",
)

9
src/ddpm/utils.py

@ -39,6 +39,11 @@ def detect_circle_centers(image, circle_colors, threshold=30, min_pixels=12):
centers (list of tuples): List of (row, col) coordinates for each detected circle.
image (np.ndarray): The original image.
"""
image_orig = image
if image.shape[0] == 3:
image = np.permute_dims(image, (1, 2, 0))
centers = []
# Loop over each target circle color.
for color in circle_colors:
@ -61,7 +66,7 @@ def detect_circle_centers(image, circle_colors, threshold=30, min_pixels=12):
center = center_of_mass(mask, labeled, i)
centers.append(center)
return centers, image
return centers, image_orig
def compute_statistics(
@ -69,8 +74,6 @@ def compute_statistics(
generated: np.array,
output_path: str = "center_statistics.png",
):
assert len(data.shape) == 4
assert data.shape[2] == data.shape[3]
image_size = data.shape[2]
data_centers = []

Loading…
Cancel
Save