6 changed files with 761 additions and 530 deletions
@@ -0,0 +1,176 @@
from typing import Dict

import torch
from tqdm import tqdm

TIMESTEPS = 1000
DDIM_TIMESTEPS = 500


@torch.compile
@torch.no_grad()
def ddpm_sample(
    model: torch.nn.Module,
    x: torch.Tensor,
    t: torch.Tensor,
    params: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Sample from the model at timestep t."""
    predicted_noise = model(x, t)

    one_over_alphas = extract(params["one_over_alphas"], t, x.shape)
    posterior_mean_coef = extract(params["posterior_mean_coef"], t, x.shape)

    pred_mean = one_over_alphas * (x - posterior_mean_coef * predicted_noise)

    posterior_variance = extract(params["posterior_variance"], t, x.shape)

    if t[0] > 0:
        noise = torch.randn_like(x)
        return pred_mean + torch.sqrt(posterior_variance) * noise
    else:
        return pred_mean
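
# Note: ddpm_sample implements the DDPM ancestral update
#   x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alphabar_t) * eps_theta(x_t, t))
#             + sqrt(posterior_variance_t) * z,   z ~ N(0, I),
# with the noise term dropped at t = 0.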


@torch.no_grad()
def ddim_sample(
    model: torch.nn.Module,
    x: torch.Tensor,
    t: torch.Tensor,
    params: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Sample from the model in a non-Markovian way (DDIM)."""
    device = next(model.parameters()).device

    stride = TIMESTEPS // DDIM_TIMESTEPS
    t_prev = t - stride
    predicted_noise = model(x, t)

    alphas_prod = extract(params["alphas_cumprod"], t, x.shape)
    valid_mask = (t_prev >= 0).view(-1, 1, 1, 1)
    safe_t_prev = torch.maximum(t_prev, torch.tensor(0, device=device))
    alphas_prod_prev = extract(params["alphas_cumprod"], safe_t_prev, x.shape)
    alphas_prod_prev = torch.where(
        valid_mask, alphas_prod_prev, torch.ones_like(alphas_prod_prev)
    )

    sigma = extract(params["ddim_sigma"], t, x.shape)

    pred_x0 = (x - (1 - alphas_prod).sqrt() * predicted_noise) / alphas_prod.sqrt()

    pred = (
        alphas_prod_prev.sqrt() * pred_x0
        + (1.0 - alphas_prod_prev).sqrt() * predicted_noise
    )

    if t[0] > 0:
        noise = torch.randn_like(x)
        pred = pred + noise * sigma

    return pred
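
# Note: ddim_sample implements the DDIM update
#   x_{t_prev} = sqrt(alphabar_{t_prev}) * x0_hat
#                + sqrt(1 - alphabar_{t_prev}) * eps_theta + sigma_t * z,
# where x0_hat = (x_t - sqrt(1 - alphabar_t) * eps_theta) / sqrt(alphabar_t).
# With the default eta = 0 this is deterministic; for eta > 0 the exact DDIM
# direction coefficient would be sqrt(1 - alphabar_{t_prev} - sigma_t^2) rather
# than the sqrt(1 - alphabar_{t_prev}) used here.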


@torch.no_grad()
def ddpm_sample_images(
    model: torch.nn.Module,
    image_size: int,
    batch_size: int,
    channels: int,
    device: torch.device,
    params: Dict[str, torch.Tensor],
):
    """Generate new images using the trained model."""
    x = torch.randn(batch_size, channels, image_size, image_size).to(device)

    for t in tqdm(
        reversed(range(TIMESTEPS)), desc="DDPM Sampling", total=TIMESTEPS, leave=False
    ):
        t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long)
        x = ddpm_sample(model, x, t_batch, params)

        if x.isnan().any():
            raise ValueError(f"NaN detected in image at timestep {t}")

    return x


def get_ddim_timesteps(
    total_timesteps: int, num_sampling_timesteps: int
) -> torch.Tensor:
    """Gets the timesteps used for the DDIM process."""
    assert total_timesteps % num_sampling_timesteps == 0
    stride = total_timesteps // num_sampling_timesteps
    timesteps = torch.arange(0, total_timesteps, stride)
    return timesteps.flip(0)
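
# Example: get_ddim_timesteps(1000, 500) -> tensor([998, 996, ..., 2, 0]),
# i.e. every `stride`-th timestep in descending order.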


@torch.no_grad()
def ddim_sample_images(
    model: torch.nn.Module,
    image_size: int,
    batch_size: int,
    channels: int,
    device: torch.device,
    params: Dict[str, torch.Tensor],
):
    """Generate new images using the trained model."""
    x = torch.randn(batch_size, channels, image_size, image_size).to(device)

    timesteps = get_ddim_timesteps(TIMESTEPS, DDIM_TIMESTEPS)

    for i in tqdm(range(len(timesteps) - 1), desc="DDIM Sampling"):
        t = torch.full((batch_size,), timesteps[i], device=device, dtype=torch.long)
        x = ddim_sample(model, x, t, params)

        if x.isnan().any():
            raise ValueError(f"NaN detected at timestep {timesteps[i]}")

    return x


def extract(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor:
    """Extract coefficients at specified timesteps t."""
    batch_size = t.shape[0]
    out = a.gather(-1, t)
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
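
# Example: for a of shape (T,), t of shape (B,), and x_shape == (B, C, H, W),
# extract returns shape (B, 1, 1, 1), so per-sample coefficients broadcast
# over the channel and spatial dimensions.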


def get_diffusion_params(
    timesteps: int,
    device: torch.device,
    ddim_timesteps: int = DDIM_TIMESTEPS,
    eta: float = 0.0,
) -> Dict[str, torch.Tensor]:
    def linear_beta_schedule(timesteps):
        beta_start = 0.0001
        beta_end = 0.02
        return torch.linspace(beta_start, beta_end, timesteps)

    betas = linear_beta_schedule(timesteps)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

    one_over_alphas = 1.0 / torch.sqrt(alphas)
    posterior_mean_coef = betas / sqrt_one_minus_alphas_cumprod

    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

    # NOTE: ddim_sigma is computed with the stride-1 alphas_cumprod_prev, not
    # the strided DDIM predecessor; with the default eta = 0 it is identically
    # zero, so this has no effect.
    ddim_sigma = eta * torch.sqrt(
        (1.0 - alphas_cumprod_prev)
        / (1.0 - alphas_cumprod)
        * (1 - alphas_cumprod / alphas_cumprod_prev)
    )

    return {
        # DDPM parameters
        "betas": betas.to(device),
        "alphas_cumprod": alphas_cumprod.to(device),
        "posterior_variance": posterior_variance.to(device),
        "one_over_alphas": one_over_alphas.to(device),
        "posterior_mean_coef": posterior_mean_coef.to(device),
        # DDIM parameters
        "ddim_sigma": ddim_sigma.to(device),
    }
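
# Usage sketch (illustrative; assumes a trained noise-prediction model with
# signature model(x, t)):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   params = get_diffusion_params(TIMESTEPS, device)
#   images = ddpm_sample_images(
#       model, image_size=32, batch_size=16, channels=3, device=device, params=params
#   )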
@@ -0,0 +1,329 @@
import os
import torch
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat

from scipy.ndimage import label, center_of_mass
import numpy as np
from tqdm import tqdm
from values import GREEN, BLUE

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from diffusion import ddpm_sample_images

matplotlib.use("Agg")


def detect_circle_centers(image, circle_colors, threshold=30, min_pixels=12):
    """
    Detects centers of circles in an image based on their known colors, filtering out
    regions that have fewer than a specified number of pixels.

    This function creates a mask for each provided circle color by selecting pixels
    whose RGB values are close to the target color (within a per-channel threshold).
    It then labels the connected regions in the mask, filters out small regions based
    on the min_pixels parameter, and computes the centers of the remaining regions.

    Args:
        image (np.ndarray): The image array with shape (H, W, 3).
        circle_colors (list of tuple): List of RGB tuples for the circle colors to
            detect, e.g. [GREEN, BLUE].
        threshold (int): Maximum allowed difference per channel between a pixel and
            the target circle color.
        min_pixels (int): Minimum number of pixels for a region to be considered valid.

    Returns:
        centers (list of tuple): List of (row, col) coordinates for each detected circle.
        image (np.ndarray): The original image.
    """
    centers = []
    # Loop over each target circle color.
    for color in circle_colors:
        # Compute the absolute difference between each pixel and the target color.
        diff = np.abs(image.astype(np.int16) - np.array(color, dtype=np.int16))
        # Create a mask: pixels where all channels are within the threshold.
        mask = np.all(diff < threshold, axis=-1)

        # Label connected regions in the mask.
        labeled, num_features = label(mask)

        # Process each labeled region.
        for i in range(1, num_features + 1):
            # Count the number of pixels in the current region.
            region_size = np.sum(labeled == i)
            # Skip regions that are smaller than the minimum required.
            if region_size < min_pixels:
                continue

            center = center_of_mass(mask, labeled, i)
            centers.append(center)

    return centers, image
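
# Example (illustrative): for a channels-last uint8 image containing one blue
# and one green circle,
#   centers, _ = detect_circle_centers(image, (BLUE, GREEN))
# returns the two (row, col) centers as float coordinates, blue first.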


def compute_statistics(
    data: np.ndarray,
    generated: np.ndarray,
    output_path: str = "center_statistics.png",
):
    assert data.ndim == 4
    # Images are channels-last: (num_samples, H, W, 3) with H == W.
    assert data.shape[1] == data.shape[2]
    image_size = data.shape[1]

    data_centers = []
    num_bad_samples = 0
    with ProcessPoolExecutor(max_workers=8) as executor:
        for centers, _ in executor.map(
            detect_circle_centers, data, repeat((BLUE, GREEN)), chunksize=8
        ):
            if len(centers) == 2:
                data_centers.append(np.array(centers))
            else:
                num_bad_samples += 1

    if num_bad_samples > 0:
        print("num bad samples in data: ", num_bad_samples)

    data_centers = np.stack(data_centers, axis=0)  # (num_samples, 2, 2)

    num_bad_samples = 0
    generated_centers = []
    with ProcessPoolExecutor(max_workers=16) as executor:
        for centers, _ in executor.map(
            detect_circle_centers, generated, repeat((BLUE, GREEN)), chunksize=8
        ):
            if len(centers) == 2:
                generated_centers.append(np.array(centers))
            else:
                num_bad_samples += 1

    if num_bad_samples > 0:
        print("num bad samples in generated: ", num_bad_samples)

    generated_centers = np.stack(generated_centers, axis=0)  # (num_samples, 2, 2)

    # Create a figure with subplots
    plt.figure(figsize=(15, 10))

    # Plot histogram of x positions
    plt.subplot(2, 2, 1)
    sns.histplot(
        data_centers[:, :, 0].reshape(-1),
        color="blue",
        label="Data",
        kde=True,
        stat="density",
    )
    sns.histplot(
        generated_centers[:, :, 0].reshape(-1),
        color="orange",
        label="Generated",
        kde=True,
        stat="density",
    )
    plt.title("X Position Distribution")
    plt.xlabel("X Position")
    plt.legend()

    # Plot histogram of y positions
    plt.subplot(2, 2, 2)
    sns.histplot(
        data_centers[:, :, 1].reshape(-1),
        color="blue",
        label="Data",
        kde=True,
        stat="density",
    )
    sns.histplot(
        generated_centers[:, :, 1].reshape(-1),
        color="orange",
        label="Generated",
        kde=True,
        stat="density",
    )
    plt.title("Y Position Distribution")
    plt.xlabel("Y Position")
    plt.legend()

    # Plot histogram of distances between the two circle centers
    plt.subplot(2, 2, 3)
    distances = np.sqrt(
        np.square(data_centers[:, ::2, 0] - data_centers[:, 1::2, 0])
        + np.square(data_centers[:, ::2, 1] - data_centers[:, 1::2, 1])
    ).squeeze()
    generated_distances = np.sqrt(
        np.square(generated_centers[:, ::2, 0] - generated_centers[:, 1::2, 0])
        + np.square(generated_centers[:, ::2, 1] - generated_centers[:, 1::2, 1])
    ).squeeze()
    sns.histplot(distances, color="blue", label="Data", kde=True, stat="density")
    sns.histplot(
        generated_distances, color="orange", label="Generated", kde=True, stat="density"
    )
    plt.title("Distance Between Circles Distribution")
    plt.xlabel("Distance")
    plt.legend()

    # Plot 2D heatmap of center positions
    plt.subplot(2, 2, 4)
    sns.kdeplot(
        x=data_centers[:, :, 0].reshape(-1),
        y=data_centers[:, :, 1].reshape(-1),
        cmap="Blues",
        label="Data",
    )
    sns.kdeplot(
        x=generated_centers[:, :, 0].reshape(-1),
        y=generated_centers[:, :, 1].reshape(-1),
        cmap="Oranges",
        label="Generated",
    )
    plt.title("2D Heatmap of Center Positions")
    plt.xlabel("X Position")
    plt.ylabel("Y Position")
    plt.xlim(0, image_size)
    plt.ylim(0, image_size)
    plt.legend()

    plt.tight_layout()
    output_path = os.path.join("media", output_path)
    plt.savefig(output_path)
    print(f"Saved histograms at {output_path}")
    plt.close()


def plot_k_centers(dataset: np.ndarray, k: int = 1):
    """From a given dataset, plot some samples that have k centers in them."""
    images = []
    with ProcessPoolExecutor(max_workers=8) as executor:
        for centers, image in executor.map(
            detect_circle_centers, dataset, repeat((BLUE, GREEN)), chunksize=8
        ):
            if len(centers) == k:
                images.append(image)

    print("num samples: ", len(images))

    images = np.stack(images, axis=0)
    show_images(images, f"{k} circles", grid_size=8)


def plot_bad_centers(dataset: np.ndarray, threshold: float = 18.0):
    """From a given dataset, plot some samples that violate the distance constraint."""
    generated_centers = []
    num_bad_samples = 0
    with ProcessPoolExecutor(max_workers=8) as executor:
        for centers, _ in executor.map(
            detect_circle_centers, dataset, repeat((BLUE, GREEN)), chunksize=8
        ):
            if len(centers) == 2:
                generated_centers.append(np.array(centers))
            else:
                # Placeholder centers so indices stay aligned with the dataset;
                # their distance is 0 and never exceeds the threshold.
                generated_centers.append(np.zeros((2, 2)))
                num_bad_samples += 1

    if num_bad_samples > 0:
        print("num bad samples in generated: ", num_bad_samples)

    generated_centers = np.stack(generated_centers, axis=0)  # (num_samples, 2, 2)
    generated_distances = np.sqrt(
        np.square(generated_centers[:, ::2, 0] - generated_centers[:, 1::2, 0])
        + np.square(generated_centers[:, ::2, 1] - generated_centers[:, 1::2, 1])
    ).squeeze()

    mask = generated_distances > threshold
    generated = dataset[mask]
    show_images(generated)


def plot_more_than_2_centers(generated: np.ndarray):
    nb_samples = 0
    images = []
    with ProcessPoolExecutor(max_workers=16) as executor:
        for centers, image in executor.map(
            detect_circle_centers, generated, repeat((BLUE, GREEN)), chunksize=8
        ):
            if len(centers) > 2:
                images.append(image)
                nb_samples += 1

    print("Nb images with more than 2 centers: ", nb_samples)

    show_images(images)


def show_images(images, title="", grid_size: int = 4):
    """Display a batch of images in a grid."""
    for idx in range(min(grid_size**2, len(images))):
        ax = plt.subplot(grid_size, grid_size, idx + 1)
        ax.imshow(images[idx])
        centers, _ = detect_circle_centers(images[idx], (BLUE, GREEN))
        # center_of_mass returns (row, col), so scatter takes (col, row).
        for center in centers:
            ax.scatter(center[1], center[0], c="red", s=20)
        ax.axis("off")
    plt.suptitle(title)
    plt.savefig("media/circles-predicted.png")
    plt.close()


def visualize_samples(
    dataset: np.ndarray, output_dir: str, filename: str = "sample_grid.png"
):
    """From a given dataset, visualize some samples."""

    # Define the grid size (e.g., 5x5)
    grid_size = 5
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(10, 10))

    for i in range(grid_size):
        for j in range(grid_size):
            idx = i * grid_size + j
            if idx < len(dataset):
                img = dataset[idx]
                axes[i, j].imshow(img)
                centers, _ = detect_circle_centers(img, (BLUE, GREEN))
                # Plot each detected center in red. Note that center_of_mass
                # returns (row, col).
                for center in centers:
                    axes[i, j].scatter(center[1], center[0], c="red", s=20)
            axes[i, j].axis("off")

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, filename))
    plt.close()


def generate_synthetic_samples(
    model,
    image_size: int,
    nb_synthetic_samples: int,
    params,
    device="cuda",
):
    model.eval()
    generated = np.empty(
        shape=(nb_synthetic_samples, image_size, image_size, 3), dtype=np.uint8
    )
    # Note: assumes nb_synthetic_samples is a multiple of chunk; any remainder
    # would be left uninitialized by np.empty.
    chunk = 128
    for i in tqdm(
        range(nb_synthetic_samples // chunk), desc="Generating synthetic data."
    ):
        generated_images = ddpm_sample_images(
            model=model,
            image_size=image_size,
            batch_size=chunk,
            channels=3,
            device=device,
            params=params,
        )
        # Convert from (N, C, H, W) in [0, 1] to channels-last uint8 images.
        generated_images = torch.permute(generated_images, (0, 2, 3, 1))
        generated[i * chunk : (i + 1) * chunk] = (
            (generated_images * 255.0)
            .clip(min=0.0, max=255.0)
            .cpu()
            .numpy()
            .astype(np.uint8)
        )

    return generated
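
# End-to-end sketch (illustrative; `model`, `data`, and `params` are assumed
# to come from the training script):
#
#   generated = generate_synthetic_samples(
#       model, image_size=32, nb_synthetic_samples=1024, params=params
#   )
#   compute_statistics(data, generated)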
@@ -0,0 +1,29 @@
RED = (0xCC, 0x24, 0x1D)
GREEN = (0x98, 0x97, 0x1A)
BLUE = (0x45, 0x85, 0x88)
WHITE = (0xFB, 0xF1, 0xC7)
BACKGROUND = (0x50, 0x49, 0x45)

EXPERIMENTS = {
    "vanilla": {
        "image_size": 32,
        "distance": 11,
        "delta": 4,
        "radius": 3,
        "rectangle_thickness": 0,
    },
    "rectangle": {
        "image_size": 32,
        "distance": 11,
        "delta": 4,
        "radius": 3,
        "rectangle_thickness": 4,
    },
    "bigrectangle": {
        "image_size": 32,
        "distance": 11,
        "delta": 4,
        "radius": 3,
        "rectangle_thickness": 8,
    },
}
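
# Example (illustrative): configs are looked up by experiment name, e.g.
#   cfg = EXPERIMENTS["rectangle"]
#   image_size, radius = cfg["image_size"], cfg["radius"]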