|
|
|
@ -1,9 +1,11 @@ |
|
|
|
import argparse |
|
|
|
import torch |
|
|
|
import matplotlib |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
import numpy as np |
|
|
|
from train import get_diffusion_params |
|
|
|
from train import TIMESTEPS, IMAGE_SIZE |
|
|
|
from diffusion import TIMESTEPS |
|
|
|
from values import EXPERIMENTS |
|
|
|
from model import UNet |
|
|
|
|
|
|
|
from generate_circle_dataset import ( |
|
|
|
@ -12,12 +14,13 @@ from generate_circle_dataset import ( |
|
|
|
|
|
|
|
from utils import compute_statistics, generate_synthetic_samples |
|
|
|
|
|
|
|
IMAGE_SIZE = 32 |
|
|
|
matplotlib.use("Agg") |
|
|
|
|
|
|
|
torch.manual_seed(1) |
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
|
|
GENERATE_SYNTHETIC_DATA = True |
|
|
|
GENERATE_SYNTHETIC_DATA = False |
|
|
|
|
|
|
|
|
|
|
|
def load_dataset(data_path: str): |
|
|
|
@ -27,19 +30,19 @@ def load_dataset(data_path: str): |
|
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
torch.set_float32_matmul_precision("high") |
|
|
|
|
|
|
|
def main(args): |
|
|
|
plt.figure(figsize=(10, 10)) |
|
|
|
|
|
|
|
data = load_dataset("./data/circle_dataset/data-bigrectangle.npy") |
|
|
|
data = load_dataset(f"./data/circle_dataset/data-{args.experiment_name}.npy") |
|
|
|
|
|
|
|
if GENERATE_SYNTHETIC_DATA: |
|
|
|
nb_synthetic_samples = 102_400 |
|
|
|
params = get_diffusion_params(TIMESTEPS, device, eta=0.0) |
|
|
|
|
|
|
|
model = UNet(IMAGE_SIZE, TIMESTEPS).to(device) |
|
|
|
model.load_state_dict(torch.load("model-bigrectangle.pkl", weights_only=True)) |
|
|
|
model.load_state_dict( |
|
|
|
torch.load(f"model-{args.experiment_name}.pkl", weights_only=True) |
|
|
|
) |
|
|
|
|
|
|
|
model.eval() |
|
|
|
generated = generate_synthetic_samples( |
|
|
|
@ -49,12 +52,34 @@ if __name__ == "__main__": |
|
|
|
params=params, |
|
|
|
device=device, |
|
|
|
) |
|
|
|
np.save("./data/circle_dataset/generated-bigrectangle.npy", generated) |
|
|
|
np.save( |
|
|
|
f"./data/circle_dataset/generated-{args.experiment_name}.npy", generated |
|
|
|
) |
|
|
|
else: |
|
|
|
generated = np.load("./data/circle_dataset/generated-bigrectangle.npy") |
|
|
|
generated = np.load( |
|
|
|
f"./data/circle_dataset/generated-{args.experiment_name}.npy" |
|
|
|
) |
|
|
|
|
|
|
|
visualize_samples(generated, "") |
|
|
|
compute_statistics(data, generated) |
|
|
|
# plot_k_centers(generated, 1) |
|
|
|
# plot_bad_centers(generated) |
|
|
|
# plot_more_than_2_centers(generated) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
torch.set_float32_matmul_precision("high") |
|
|
|
|
|
|
|
parser = argparse.ArgumentParser( |
|
|
|
desc="Generate samples from a trained diffusion model and save to disk." |
|
|
|
) |
|
|
|
parser.add_argument( |
|
|
|
"-e", |
|
|
|
"--experiment_name", |
|
|
|
default="vanilla", |
|
|
|
type=str, |
|
|
|
choices=EXPERIMENTS.keys(), |
|
|
|
) |
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
main(args) |
|
|
|
|