6 changed files with 761 additions and 530 deletions
@ -0,0 +1,176 @@ |
|||||
|
from typing import Dict |
||||
|
import torch |
||||
|
from tqdm import tqdm |
||||
|
|
||||
|
# Number of diffusion steps in the full (training / DDPM sampling) schedule.
TIMESTEPS = 1000

# Number of strided steps used for accelerated DDIM sampling; must divide TIMESTEPS.
DDIM_TIMESTEPS = 500
|
@torch.compile
@torch.no_grad()
def ddpm_sample(
    model: torch.nn.Module,
    x: torch.Tensor,
    t: torch.Tensor,
    params: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Sample from the model at timestep t.

    Performs one reverse DDPM step: given x at timestep t, returns a sample
    of x at timestep t - 1. All entries of `t` are assumed to share the same
    timestep (the noise branch is decided by t[0]).
    """
    eps = model(x, t)

    # Gather per-timestep coefficients, broadcast against x.
    inv_sqrt_alpha = extract(params["one_over_alphas"], t, x.shape)
    eps_coef = extract(params["posterior_mean_coef"], t, x.shape)

    # Posterior mean of x_{t-1} given x_t and the predicted noise.
    mean = inv_sqrt_alpha * (x - eps_coef * eps)

    variance = extract(params["posterior_variance"], t, x.shape)

    # The final step (t == 0) is deterministic; otherwise add scaled noise.
    if t[0] > 0:
        return mean + variance.sqrt() * torch.randn_like(x)
    return mean
|
@torch.no_grad()
def ddim_sample(
    model: torch.nn.Module,
    x: torch.Tensor,
    t: torch.Tensor,
    params: Dict[str, torch.Tensor],
    device=None,
) -> torch.Tensor:
    """Sample from the model in a non-markovian way (DDIM).

    Performs one strided DDIM step from timestep t to t - stride.

    Args:
        model: Noise-prediction network called as model(x, t).
        x: Current noisy batch, same shape as the model output.
        t: LongTensor of shape (batch,); all entries assumed equal.
        params: Precomputed tables from get_diffusion_params.
        device: Optional device for index tensors; defaults to the
            model's device when omitted.
    """
    # Bug fix: `device` used to be a required positional argument that was
    # then unconditionally overwritten with the model's device, and the
    # in-file caller omitted it entirely (TypeError). It is now optional.
    if device is None:
        device = next(model.parameters()).device

    stride = TIMESTEPS // DDIM_TIMESTEPS
    t_prev = t - stride
    predicted_noise = model(x, t)

    alphas_prod = extract(params["alphas_cumprod"], t, x.shape)
    # For entries that would step past t = 0, clamp the index and substitute
    # alpha_bar_prev = 1 (a fully denoised state).
    valid_mask = (t_prev >= 0).view(-1, 1, 1, 1)
    safe_t_prev = torch.maximum(t_prev, torch.tensor(0, device=device))
    alphas_prod_prev = extract(params["alphas_cumprod"], safe_t_prev, x.shape)
    alphas_prod_prev = torch.where(
        valid_mask, alphas_prod_prev, torch.ones_like(alphas_prod_prev)
    )

    # NOTE(review): ddim_sigma is precomputed from the *unit-step* previous
    # alpha_bar, while sampling here strides by TIMESTEPS // DDIM_TIMESTEPS —
    # confirm this matches the intended eta schedule.
    sigma = extract(params["ddim_sigma"], t, x.shape)

    # Predict x_0 from the noise estimate, then re-noise toward t_prev
    # (Eq. 12 of the DDIM paper).
    pred_x0 = (x - (1 - alphas_prod).sqrt() * predicted_noise) / alphas_prod.sqrt()

    pred = (
        alphas_prod_prev.sqrt() * pred_x0
        + (1.0 - alphas_prod_prev).sqrt() * predicted_noise
    )

    # Stochastic component (zero when eta == 0); skipped on the last step.
    if t[0] > 0:
        noise = torch.randn_like(x)
        pred = pred + noise * sigma

    return pred
|
@torch.no_grad()
def ddpm_sample_images(
    model: torch.nn.Module,
    image_size: int,
    batch_size: int,
    channels: int,
    device: torch.device,
    params: Dict[str, torch.Tensor],
):
    """Generate new images using the trained model"""
    # Start from pure Gaussian noise and denoise one timestep at a time.
    x = torch.randn(batch_size, channels, image_size, image_size).to(device)

    progress = tqdm(
        reversed(range(TIMESTEPS)), desc="DDPM Sampling", total=TIMESTEPS, leave=False
    )
    for step in progress:
        step_batch = torch.full((batch_size,), step, device=device, dtype=torch.long)
        x = ddpm_sample(model, x, step_batch, params)

        # Fail fast if sampling diverged.
        if x.isnan().any():
            raise ValueError(f"NaN detected in image at timestep {step}")

    return x
|
def get_ddim_timesteps(
    total_timesteps: int, num_sampling_timesteps: int
) -> torch.Tensor:
    """Gets the timesteps used for the DDIM process.

    Args:
        total_timesteps: Length of the full (training) schedule.
        num_sampling_timesteps: Number of strided sampling steps; must divide
            total_timesteps evenly.

    Returns:
        1-D LongTensor of timesteps in descending order, e.g.
        (1000, 500) -> [998, 996, ..., 2, 0].

    Raises:
        ValueError: If num_sampling_timesteps does not divide total_timesteps.
    """
    # Bug fix: was an `assert`, which is stripped under `python -O`;
    # input validation must raise explicitly.
    if total_timesteps % num_sampling_timesteps != 0:
        raise ValueError(
            "total_timesteps must be a multiple of num_sampling_timesteps, "
            f"got {total_timesteps} and {num_sampling_timesteps}"
        )
    stride = total_timesteps // num_sampling_timesteps
    timesteps = torch.arange(0, total_timesteps, stride)
    return timesteps.flip(0)
|
@torch.no_grad()
def ddim_sample_images(
    model: torch.nn.Module,
    image_size: int,
    batch_size: int,
    channels: int,
    device: torch.device,
    params: Dict[str, torch.Tensor],
):
    """Generate new images using the trained model (DDIM sampler).

    Starts from Gaussian noise and applies the strided DDIM update along the
    descending timestep schedule. The final schedule entry (t == 0) is not
    processed: each call already advances x by one stride, so processing
    t == stride lands on t == 0.

    Raises:
        ValueError: If NaNs appear in the sample at any timestep.
    """
    x = torch.randn(batch_size, channels, image_size, image_size).to(device)

    timesteps = get_ddim_timesteps(TIMESTEPS, DDIM_TIMESTEPS)

    for i in tqdm(range(len(timesteps) - 1), desc="DDIM Sampling"):
        t = torch.full((batch_size,), timesteps[i], device=device, dtype=torch.long)
        # Bug fix: `device` was not forwarded, which made this call raise
        # TypeError (ddim_sample declared five required positional parameters).
        x = ddim_sample(model, x, t, params, device)

        if x.isnan().any():
            raise ValueError(f"NaN detected at timestep {timesteps[i]}")

    return x
|
def extract(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size):
    """Extract coefficients at specified timesteps t.

    Args:
        a: 1-D tensor of per-timestep coefficients.
        t: LongTensor of shape (batch,) holding one timestep index per sample.
        x_shape: Shape of the tensor the result will broadcast against.

    Returns:
        Tensor of shape (batch, 1, ..., 1) with len(x_shape) dimensions,
        on t's device.
    """
    # Bug fix: the annotation was `torch.Tensor.shape`, which is an attribute
    # descriptor, not a type; `torch.Size` is the actual shape type.
    batch_size = t.shape[0]
    out = a.gather(-1, t)
    # Reshape to (batch, 1, 1, ...) so the coefficients broadcast over x.
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
|
def get_diffusion_params(
    timesteps: int,
    device: torch.device,
    ddim_timesteps: int = DDIM_TIMESTEPS,
    eta=0.0,
) -> Dict[str, torch.Tensor]:
    """Precompute the per-timestep coefficient tables for DDPM/DDIM sampling.

    Args:
        timesteps: Length of the linear beta schedule.
        device: Device every returned table is moved to.
        ddim_timesteps: Kept for interface compatibility (unused here).
        eta: DDIM stochasticity; 0.0 gives fully deterministic DDIM.

    Returns:
        Dict of 1-D tensors of length `timesteps`, all on `device`.
    """
    # Linear beta schedule from 1e-4 to 2e-2.
    betas = torch.linspace(0.0001, 0.02, timesteps)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1.
    alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

    one_over_alphas = 1.0 / torch.sqrt(alphas)
    posterior_mean_coef = betas / sqrt_one_minus_alphas_cumprod

    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

    # DDIM sigma schedule (zero everywhere when eta == 0).
    ddim_sigma = eta * torch.sqrt(
        (1.0 - alphas_cumprod_prev)
        / (1.0 - alphas_cumprod)
        * (1 - alphas_cumprod / alphas_cumprod_prev)
    )

    tables = {
        # DDPM Parameters
        "betas": betas,
        "alphas_cumprod": alphas_cumprod,
        "posterior_variance": posterior_variance,
        "one_over_alphas": one_over_alphas,
        "posterior_mean_coef": posterior_mean_coef,
        # DDIM Parameters
        "ddim_sigma": ddim_sigma,
    }
    return {name: table.to(device) for name, table in tables.items()}
||||
@ -0,0 +1,329 @@ |
|||||
|
import os |
||||
|
import torch |
||||
|
from concurrent.futures import ProcessPoolExecutor |
||||
|
from itertools import repeat |
||||
|
|
||||
|
from scipy.ndimage import label, center_of_mass |
||||
|
import numpy as np |
||||
|
from tqdm import tqdm |
||||
|
from values import GREEN, BLUE |
||||
|
|
||||
|
import matplotlib |
||||
|
import matplotlib.pyplot as plt |
||||
|
import seaborn as sns |
||||
|
|
||||
|
from diffusion import ddpm_sample_images |
||||
|
|
||||
|
# Headless backend: figures are only written to disk, never shown interactively.
matplotlib.use("Agg")
|
def detect_circle_centers(image, circle_colors, threshold=30, min_pixels=12):
    """Locate circle centers in an RGB image by color matching.

    For every target color, pixels whose RGB values lie within `threshold` of
    that color on all three channels are selected into a mask. The mask is
    split into connected components, components smaller than `min_pixels` are
    discarded, and the center of mass of each surviving component is recorded.

    Args:
        image (np.ndarray): Image array with shape (H, W, 3).
        circle_colors (list of tuple): RGB tuples to detect, e.g. [GREEN, BLUE].
        threshold (int): Per-channel tolerance when matching a pixel to a color.
        min_pixels (int): Smallest connected component kept as a detection.

    Returns:
        centers (list of tuples): (row, col) coordinates of detected circles.
        image (np.ndarray): The original, unmodified image.
    """
    found = []
    # Signed copy so the per-channel difference cannot wrap around.
    signed = image.astype(np.int16)
    for target in circle_colors:
        # A pixel matches when every channel is within the tolerance.
        channel_diff = np.abs(signed - np.array(target, dtype=np.int16))
        mask = np.all(channel_diff < threshold, axis=-1)

        labeled, num_regions = label(mask)

        for region_id in range(1, num_regions + 1):
            # Drop speckle: regions below the minimum pixel count.
            if np.sum(labeled == region_id) < min_pixels:
                continue
            found.append(center_of_mass(mask, labeled, region_id))

    return found, image
|
def compute_statistics(
    data: np.array,
    generated: np.array,
    output_path: str = "center_statistics.png",
):
    """Compare circle-center statistics of real vs. generated image batches.

    Detects the two circle centers in every image of both batches (in
    parallel worker processes), drops images where exactly two centers were
    not found, and saves a 2x2 figure under media/<output_path>: position
    histograms for both coordinates, an inter-circle distance histogram, and
    a 2D KDE of center positions.

    Args:
        data (np.array): Real images, shape (N, H, W, 3) with H == W.
        generated (np.array): Generated images, same layout as `data`.
        output_path (str): File name for the saved figure (inside "media/").
    """
    # NOTE(review): `assert` is stripped under -O; these are shape guards only.
    assert len(data.shape) == 4
    assert data.shape[2] == data.shape[3]
    image_size = data.shape[2]

    # Detect centers in the real data; images without exactly two detected
    # centers are excluded from the statistics and counted as bad.
    data_centers = []
    num_bad_samples = 0
    with ProcessPoolExecutor(max_workers=8) as executor:
        for centers, _ in executor.map(
            detect_circle_centers, data, repeat((BLUE, GREEN)), chunksize=8
        ):
            if len(centers) == 2:
                data_centers.append(np.array(centers))
            else:
                num_bad_samples += 1

    if num_bad_samples > 0:
        print("num bad samples in data: ", num_bad_samples)

    data_centers = np.stack(data_centers, axis=0)  # (num_samples, 2, 2)

    # Same detection pass for the generated batch.
    num_bad_samples = 0
    generated_centers = []
    with ProcessPoolExecutor(max_workers=16) as executor:
        for centers, _ in executor.map(
            detect_circle_centers, generated, repeat((BLUE, GREEN)), chunksize=8
        ):
            if len(centers) == 2:
                generated_centers.append(np.array(centers))
            else:
                num_bad_samples += 1

    if num_bad_samples > 0:
        print("num bad samples in generated: ", num_bad_samples)

    generated_centers = np.stack(generated_centers, axis=0)  # (num_samples, 2, 2)

    # Create a figure with subplots
    plt.figure(figsize=(15, 10))

    # Plot histogram of x positions
    # NOTE(review): centers come from center_of_mass as (row, col), so index 0
    # is the vertical coordinate — the "X"/"Y" labels may be swapped; confirm.
    plt.subplot(2, 2, 1)
    sns.histplot(
        data_centers[:, :, 0].reshape(-1),
        color="blue",
        label="Data",
        kde=True,
        stat="density",
    )
    sns.histplot(
        generated_centers[:, :, 0].reshape(-1),
        color="orange",
        label="Generated",
        kde=True,
        stat="density",
    )
    plt.title("X Position Distribution")
    plt.xlabel("X Position")
    plt.legend()

    # Plot histogram of y positions
    plt.subplot(2, 2, 2)
    sns.histplot(
        data_centers[:, :, 1].reshape(-1),
        color="blue",
        label="Data",
        kde=True,
        stat="density",
    )
    sns.histplot(
        generated_centers[:, :, 1].reshape(-1),
        color="orange",
        label="Generated",
        kde=True,
        stat="density",
    )
    plt.title("Y Position Distribution")
    plt.xlabel("Y Position")
    plt.legend()

    # Plot histogram of distances
    plt.subplot(2, 2, 3)
    # Euclidean distance between the two centers of each image
    # (::2 selects center 0, 1::2 selects center 1).
    distances = np.sqrt(
        np.square(data_centers[:, ::2, 0] - data_centers[:, 1::2, 0])
        + np.square(data_centers[:, ::2, 1] - data_centers[:, 1::2, 1])
    ).squeeze()
    generated_distances = np.sqrt(
        np.square(generated_centers[:, ::2, 0] - generated_centers[:, 1::2, 0])
        + np.square(generated_centers[:, ::2, 1] - generated_centers[:, 1::2, 1])
    ).squeeze()
    sns.histplot(distances, color="blue", label="Data", kde=True, stat="density")
    sns.histplot(
        generated_distances, color="orange", label="Generated", kde=True, stat="density"
    )
    plt.title("Distance between circles distribution")
    plt.xlabel("Distance")
    plt.legend()

    # Plot 2D heatmap of center positions
    plt.subplot(2, 2, 4)
    sns.kdeplot(
        x=data_centers[:, :, 0].reshape(-1),
        y=data_centers[:, :, 1].reshape(-1),
        cmap="Blues",
        label="Data",
    )
    sns.kdeplot(
        x=generated_centers[:, :, 0].reshape(-1),
        y=generated_centers[:, :, 1].reshape(-1),
        cmap="Oranges",
        label="Generated",
    )
    plt.title("2D Heatmap of Center Positions")
    plt.xlabel("X Position")
    plt.ylabel("Y Position")
    plt.xlim(0, image_size)
    plt.ylim(0, image_size)
    plt.legend()

    plt.tight_layout()
    output_path = os.path.join("media", output_path)
    plt.savefig(output_path)
    print(f"Saved histograms at {output_path}")
    plt.close()
|
def plot_k_centers(dataset: np.array, k: int = 1):
    """From a given dataset, plot some samples that have k centers in them."""
    matching = []
    with ProcessPoolExecutor(max_workers=8) as executor:
        detections = executor.map(
            detect_circle_centers, dataset, repeat((BLUE, GREEN)), chunksize=8
        )
        # Keep only the images in which exactly k centers were detected.
        for centers, image in detections:
            if len(centers) == k:
                matching.append(image)

    print("num samples: ", len(matching))

    show_images(np.stack(matching, axis=0), f"{k} circles", grid_size=8)
|
def plot_bad_centers(dataset: np.array, threshold: float = 18.0):
    """From a given dataset, plot some samples that violate the distance constraint."""
    all_centers = []
    bad_count = 0
    with ProcessPoolExecutor(max_workers=8) as executor:
        for centers, _ in executor.map(
            detect_circle_centers, dataset, repeat((BLUE, GREEN)), chunksize=8
        ):
            if len(centers) != 2:
                # Placeholder keeps indices aligned with `dataset`; its
                # distance is 0, so it can never exceed the threshold.
                all_centers.append(np.zeros((2, 2)))
                bad_count += 1
            else:
                all_centers.append(np.array(centers))

    if bad_count > 0:
        print("num bad samples in generated: ", bad_count)

    stacked = np.stack(all_centers, axis=0)  # (num_samples, 2, 2)
    # Euclidean distance between the two centers of each image.
    deltas = stacked[:, ::2, :] - stacked[:, 1::2, :]
    pair_distances = np.sqrt(
        np.square(deltas[..., 0]) + np.square(deltas[..., 1])
    ).squeeze()

    show_images(dataset[pair_distances > threshold])
|
def plot_more_than_2_centers(generated: np.array):
    """Plot generated samples that contain more than two detected circle centers.

    Args:
        generated (np.array): Batch of images, shape (N, H, W, 3).
    """
    nb_samples = 0
    images = []
    with ProcessPoolExecutor(max_workers=16) as executor:
        # Bug fix: the color tuple was passed as `((BLUE, GREEN))` (no
        # `repeat`), so executor.map zipped against it and processed only the
        # first two images, each with a single bare color.
        for centers, image in executor.map(
            detect_circle_centers, generated, repeat((BLUE, GREEN)), chunksize=8
        ):
            # Bug fix: the condition was `== 2`, contradicting the function's
            # stated purpose of finding images with MORE than two centers.
            if len(centers) > 2:
                images.append(image)
                nb_samples += 1

    print("Nb images multiple centers: ", nb_samples)

    # Bug fix: previously plotted the whole `generated` batch instead of the
    # images that were actually selected.
    if images:
        show_images(np.stack(images, axis=0))
|
def show_images(images, title="", grid_size: int = 4):
    """Display a batch of images in a grid"""
    shown = min(grid_size**2, len(images))
    for position in range(shown):
        axis = plt.subplot(grid_size, grid_size, position + 1)
        axis.imshow(images[position])
        detected, _ = detect_circle_centers(images[position], (BLUE, GREEN))
        # center_of_mass returns (row, col); scatter expects (x, y).
        for row, col in detected:
            axis.scatter(col, row, c="red", s=20)
        axis.axis("off")
    plt.suptitle(title)
    plt.savefig("media/circles-predicted.png")
    plt.close()
|
def visualize_samples(
    dataset: np.array, output_dir: str, filename: str = "sample_grid.png"
):
    """From a given dataset, visualize some samples"""

    # Define the grid size (e.g., 5x5)
    grid_size = 5
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(10, 10))

    # axes.flat walks the grid row-major, i.e. cell == i * grid_size + j.
    for cell, ax in enumerate(axes.flat):
        if cell < len(dataset):
            img = dataset[cell]
            ax.imshow(img)
            detected, _ = detect_circle_centers(img, (BLUE, GREEN))
            # center_of_mass returns (row, col); scatter expects (x, y).
            for row, col in detected:
                ax.scatter(col, row, c="red", s=20)
            ax.axis("off")

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, filename))
    plt.close()
|
def generate_synthetic_samples(
    model,
    image_size: int,
    nb_synthetic_samples: int,
    params,
    device="cuda",
):
    """Sample nb_synthetic_samples images from a trained DDPM, in batches.

    Args:
        model: Trained noise-prediction network (switched to eval mode here).
        image_size: Height/width of the square images to generate.
        nb_synthetic_samples: Total number of images to produce.
        params: Diffusion coefficient tables (see get_diffusion_params).
        device: Device to sample on.

    Returns:
        np.ndarray of shape (nb_synthetic_samples, image_size, image_size, 3),
        dtype uint8.
    """
    model.eval()
    generated = np.empty(
        shape=(nb_synthetic_samples, image_size, image_size, 3), dtype=np.uint8
    )
    chunk = 128
    # Bug fix: the loop previously ran nb_synthetic_samples // chunk full
    # batches, so when the total was not a multiple of `chunk` the tail of
    # `generated` was returned uninitialized (np.empty garbage).
    for start in tqdm(
        range(0, nb_synthetic_samples, chunk), desc="Generating synthetic data."
    ):
        batch_size = min(chunk, nb_synthetic_samples - start)
        generated_images = ddpm_sample_images(
            model=model,
            image_size=image_size,
            batch_size=batch_size,
            channels=3,
            device=device,
            params=params,
        )
        # (B, C, H, W) -> (B, H, W, C) for image-style storage.
        generated_images = torch.permute(generated_images, (0, 2, 3, 1))
        # NOTE(review): assumes the sampler outputs values in [0, 1] before
        # the x255 rescale — confirm against the training normalization.
        generated[start : start + batch_size] = (
            (generated_images * 255.0)
            .clip(min=0.0, max=255.0)
            .cpu()
            .numpy()
            .astype(np.uint8)
        )

    return generated
||||
@ -0,0 +1,29 @@ |
|||||
|
# Color palette as RGB tuples.
# NOTE(review): values appear to match the gruvbox theme palette — presumably
# chosen for visually distinct hues; confirm if the exact values matter.
RED = (0xCC, 0x24, 0x1D)
GREEN = (0x98, 0x97, 0x1A)
BLUE = (0x45, 0x85, 0x88)
WHITE = (0xFB, 0xF1, 0xC7)
BACKGROUND = (0x50, 0x49, 0x45)

# Dataset-generation presets. All share a 32x32 canvas, radius-3 circles,
# distance 11 and delta 4; they differ only in `rectangle_thickness`
# (0 = no rectangle). NOTE(review): key semantics inferred from names —
# confirm against the dataset generator that consumes these.
EXPERIMENTS = {
    "vanilla": {
        "image_size": 32,
        "distance": 11,
        "delta": 4,
        "radius": 3,
        "rectangle_thickness": 0,
    },
    "rectangle": {
        "image_size": 32,
        "distance": 11,
        "delta": 4,
        "radius": 3,
        "rectangle_thickness": 4,
    },
    "bigrectangle": {
        "image_size": 32,
        "distance": 11,
        "delta": 4,
        "radius": 3,
        "rectangle_thickness": 8,
    },
}
||||
Loading…
Reference in new issue