Browse Source

feat: add video dataset generation

master
CALVO GONZALEZ Ramon 8 months ago
parent
commit
d9c4afaca4
  1. 304
      src/ddpm/generate_circle_dataset.py
  2. 105
      src/ddpm/utils.py

304
src/ddpm/generate_circle_dataset.py

@ -1,5 +1,7 @@
from typing import Callable
import argparse import argparse
import numpy as np import numpy as np
import math
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
from tqdm import tqdm from tqdm import tqdm
@ -10,13 +12,272 @@ from itertools import repeat
import matplotlib import matplotlib
from utils import visualize_samples from utils import visualize_samples, visualize_sequences
from values import GREEN, BLUE, WHITE, BACKGROUND, EXPERIMENTS from values import GREEN, BLUE, WHITE, BACKGROUND, EXPERIMENTS
matplotlib.use("Agg") matplotlib.use("Agg")
NUM_FRAMES = 3
def calculate_position(pos0, vel, acc, t):
"""Calculates position at time t using initial pos, vel, and const acc."""
x0, y0 = pos0
vx, vy = vel
ax, ay = acc
xt = x0 + vx * t + 0.5 * ax * t**2
yt = y0 + vy * t + 0.5 * ay * t**2
return xt, yt
# --- Modified check_constraints function ---
def check_constraints(
pos0, pos1, size, radius, distance, delta, scale, enable_constraint
):
"""
Checks distance and boundary constraints for a single frame.
Args:
pos0: (x, y) tuple for circle 0 position (high-res).
pos1: (x, y) tuple for circle 1 position (high-res).
size: High-resolution image size.
radius: High-resolution radius.
distance: Target distance between centers (low-res).
delta: Allowed tolerance for distance (low-res).
scale: The scaling factor used for high-resolution. <--- Added argument
enable_constraint: Boolean indicating if distance constraint is active.
Returns:
True if constraints are met, False otherwise.
"""
x0, y0 = pos0
x1, y1 = pos1
# 1. Boundary Check (circles must be fully inside)
# Use floor/ceil to be safe with floating point positions near boundary
if not (
radius <= math.floor(x0)
and math.ceil(x0) <= size - radius
and radius <= math.floor(y0)
and math.ceil(y0) <= size - radius
and radius <= math.floor(x1)
and math.ceil(x1) <= size - radius
and radius <= math.floor(y1)
and math.ceil(y1) <= size - radius
):
# print(f"Boundary fail: ({x0:.1f},{y0:.1f}), ({x1:.1f},{y1:.1f}) in size {size} w/ radius {radius}") # Debugging
return False # Out of bounds
# 2. Distance Check (if enabled)
if enable_constraint:
dist_sq = (x0 - x1) ** 2 + (y0 - y1) ** 2
# Scale the target distance and delta to high-resolution units for comparison
min_dist_highres = (distance - delta) * scale
max_dist_highres = (distance + delta) * scale
# Handle cases where min_dist_highres might be negative if delta > distance
min_dist_sq = (
max(0, min_dist_highres) ** 2
) # Distance squared cannot be negative
max_dist_sq = max_dist_highres**2
# Check if the squared distance is within the allowed squared range
if not (min_dist_sq <= dist_sq <= max_dist_sq):
# Add a small tolerance for floating point comparisons, especially if delta is small
tolerance_sq = 1e-9
if delta > 0 and (
dist_sq < min_dist_sq - tolerance_sq
or dist_sq > max_dist_sq + tolerance_sq
):
# print(f"Distance fail: sqrt({dist_sq:.2f}) not in [{min_dist_highres:.2f}, {max_dist_highres:.2f}]") # Debugging
return False # Distance constraint violated
elif delta == 0 and abs(dist_sq - (distance * scale) ** 2) > tolerance_sq:
# print(f"Exact Distance fail: sqrt({dist_sq:.2f}) != {distance*scale:.2f}") # Debugging
return False # Exact distance constraint violated
return True # All constraints met for this frame
def create_sample_sequence_antialiased(
id: int,
image_size: int,
distance: int,
radius: int,
delta: int,
scale: int = 4,
enable_constraint: bool = True,
enable_rectangle: bool = False,
rectangle_thickness: int = 0,
max_initial_speed: float = 5.0,
max_acceleration: float = 1.0,
):
"""
Generates a sequence of NUM_FRAMES images with two moving circles.
Applies rejection sampling to ensure constraints are met across all frames.
Movement includes linear velocity and constant acceleration (curvature).
"""
high_res_size = image_size * scale
high_res_radius = radius * scale
high_res_max_speed = max_initial_speed * scale
high_res_max_accel = max_acceleration * scale
# Ensure radius isn't too large for the image size
if high_res_radius > high_res_size / 2:
raise ValueError("Radius is too large for the image size.")
attempts = 0
max_attempts = 10_000 # Safety break to prevent infinite loops
while attempts < max_attempts:
attempts += 1
all_frames_valid = True
positions0 = []
positions1 = []
# --- Generate Initial State (Frame 0) and Movement ---
low_bound = high_res_radius
high_bound = high_res_size - high_res_radius
# Ensure initial positions are valid before starting loop
# Add a small buffer if possible to avoid immediate boundary issues
buffer = 1 # Add a small pixel buffer from the edge
safe_low = low_bound + buffer
safe_high = high_bound - buffer
if safe_low >= safe_high: # Handle cases where radius is very large
safe_low = low_bound
safe_high = high_bound + 1e-6 # Add epsilon for randint upper bound
x0_0, y0_0 = np.random.uniform(low=safe_low, high=safe_high, size=2)
x1_0, y1_0 = np.random.uniform(low=safe_low, high=safe_high, size=2)
initial_pos0 = (x0_0, y0_0)
initial_pos1 = (x1_0, y1_0)
angle0 = np.random.uniform(0, 2 * np.pi)
speed0 = np.random.uniform(0, high_res_max_speed)
vel0 = (speed0 * np.cos(angle0), speed0 * np.sin(angle0))
angle1 = np.random.uniform(0, 2 * np.pi)
speed1 = np.random.uniform(0, high_res_max_speed)
vel1 = (speed1 * np.cos(angle1), speed1 * np.sin(angle1))
angle_a0 = np.random.uniform(0, 2 * np.pi)
mag_a0 = np.random.uniform(0, high_res_max_accel)
acc0 = (mag_a0 * np.cos(angle_a0), mag_a0 * np.sin(angle_a0))
angle_a1 = np.random.uniform(0, 2 * np.pi)
mag_a1 = np.random.uniform(0, high_res_max_accel)
acc1 = (mag_a1 * np.cos(angle_a1), mag_a1 * np.sin(angle_a1))
# --- Calculate positions and check constraints for all frames ---
for t in range(NUM_FRAMES):
pos0_t = calculate_position(initial_pos0, vel0, acc0, t)
pos1_t = calculate_position(initial_pos1, vel1, acc1, t)
# --- Updated call to check_constraints ---
if not check_constraints(
pos0_t,
pos1_t,
high_res_size,
high_res_radius,
distance,
delta,
scale,
enable_constraint, # Pass scale here
):
all_frames_valid = False
break # No need to check further frames for this attempt
positions0.append(pos0_t)
positions1.append(pos1_t)
# --- If all frames are valid, generate images ---
if all_frames_valid:
# print(f"Valid sequence found after {attempts} attempts.") # Debugging
frames = []
for t in range(NUM_FRAMES):
x0, y0 = positions0[t]
x1, y1 = positions1[t]
im = Image.new("RGB", (high_res_size, high_res_size), BACKGROUND)
draw = ImageDraw.Draw(im)
if enable_rectangle:
dx = x1 - x0
dy = y1 - y0
d_sq = dx**2 + dy**2
if d_sq > 1e-9:
d = math.sqrt(d_sq)
ux = dx / d
uy = dy / d
else:
ux, uy = 1.0, 0.0
start_x = x0 - high_res_radius * ux
start_y = y0 - high_res_radius * uy
end_x = x1 + high_res_radius * ux
end_y = y1 + high_res_radius * uy
thickness = max(rectangle_thickness * scale, 2 * high_res_radius)
half_thickness = thickness / 2.0
perp_x = -uy
perp_y = ux
p1 = (
start_x + half_thickness * perp_x,
start_y + half_thickness * perp_y,
)
p2 = (
start_x - half_thickness * perp_x,
start_y - half_thickness * perp_y,
)
p3 = (
end_x - half_thickness * perp_x,
end_y - half_thickness * perp_y,
)
p4 = (
end_x + half_thickness * perp_x,
end_y + half_thickness * perp_y,
)
draw.polygon([p1, p2, p3, p4], fill=WHITE)
# Draw circles using ellipse for floating point centers
# Bounding box for ellipse: (left, top, right, bottom)
draw.ellipse(
(
x0 - high_res_radius,
y0 - high_res_radius,
x0 + high_res_radius,
y0 + high_res_radius,
),
fill=GREEN,
)
draw.ellipse(
(
x1 - high_res_radius,
y1 - high_res_radius,
x1 + high_res_radius,
y1 + high_res_radius,
),
fill=BLUE,
)
im_resized = im.resize((image_size, image_size), resample=Image.LANCZOS)
frames.append(np.array(im_resized).astype(np.uint8))
return id, np.stack(frames, axis=0) # Return the sequence of valid frames
# If max_attempts reached without finding a valid sequence
raise RuntimeError(
f"Failed to generate a valid sequence after {max_attempts} attempts. "
"Check constraints, movement parameters, or image size/radius."
)
def create_sample_antialiased( def create_sample_antialiased(
id: int, id: int,
image_size: int, image_size: int,
@ -100,6 +361,7 @@ def generate_circle_dataset(
radius, radius,
distance, distance,
delta, delta,
generation_fn: Callable,
enable_constraint: bool = True, enable_constraint: bool = True,
enable_rectangle: bool = False, enable_rectangle: bool = False,
rectangle_thickness: int = 0, rectangle_thickness: int = 0,
@ -117,7 +379,7 @@ def generate_circle_dataset(
with ProcessPoolExecutor(max_workers=32) as executor: with ProcessPoolExecutor(max_workers=32) as executor:
for i, sample in executor.map( for i, sample in executor.map(
create_sample_antialiased, generation_fn,
range(num_samples), range(num_samples),
repeat(image_size), repeat(image_size),
repeat(distance), repeat(distance),
@ -138,10 +400,19 @@ def main(args, experiment):
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
dataset = np.empty((args.total_samples, image_size, image_size, 3), dtype=np.uint8) dataset = (
np.empty((args.total_samples, image_size, image_size, 3), dtype=np.uint8)
if not args.video
else np.empty(
(args.total_samples, NUM_FRAMES, image_size, image_size, 3), dtype=np.uint8
)
)
iterator = generate_circle_dataset( iterator = generate_circle_dataset(
image_size=image_size, image_size=image_size,
num_samples=args.total_samples, num_samples=args.total_samples,
generation_fn=create_sample_antialiased
if not args.video
else create_sample_sequence_antialiased,
distance=experiment["distance"], distance=experiment["distance"],
delta=experiment["delta"], delta=experiment["delta"],
radius=experiment["radius"], radius=experiment["radius"],
@ -152,12 +423,25 @@ def main(args, experiment):
for i, sample in tqdm(iterator, total=args.total_samples): for i, sample in tqdm(iterator, total=args.total_samples):
dataset[i] = sample dataset[i] = sample
visualize_samples( if not args.video:
dataset, args.output_dir, f"sample_grid_{args.experiment_name}.png" visualize_samples(
) dataset, args.output_dir, f"sample_grid_{args.experiment_name}.png"
)
else:
visualize_sequences(
dataset,
args.output_dir,
f"sequence_grid_{args.experiment_name}.png",
num_frames_to_show=NUM_FRAMES,
)
# Save the dataset # Save the dataset
np.save(os.path.join(args.output_dir, f"data-{args.experiment_name}.npy"), dataset) dataset_filename = (
f"data-{args.experiment_name}.npy"
if not args.video
else f"data-video-{args.experiment_name}.npy"
)
np.save(os.path.join(args.output_dir, dataset_filename), dataset)
# np.savez_compressed(os.path.join(output_dir, "data.npy.npz"), dataset) # np.savez_compressed(os.path.join(output_dir, "data.npy.npz"), dataset)
@ -168,6 +452,12 @@ if __name__ == "__main__":
) )
parser.add_argument("-o", "--output_dir", default="data/circle_dataset") parser.add_argument("-o", "--output_dir", default="data/circle_dataset")
parser.add_argument("-t", "--total_samples", default=1_000_000, type=int) parser.add_argument("-t", "--total_samples", default=1_000_000, type=int)
parser.add_argument(
"-v",
"--video",
action="store_true",
help="If set, generate a dataset of videos of moving circles",
)
args = parser.parse_args() args = parser.parse_args()
main(args, EXPERIMENTS[args.experiment_name]) main(args, EXPERIMENTS[args.experiment_name])

105
src/ddpm/utils.py

@ -297,6 +297,111 @@ def visualize_samples(
plt.close() plt.close()
def visualize_sequences(
dataset: np.array,
output_dir: str,
filename: str = "sequence_grid.png",
grid_size: int = 4, # Adjust grid size as needed
num_frames_to_show: int = 3, # Number of frames per sequence (should match dataset)
):
"""
Visualizes random samples from a video dataset.
Each cell in the grid shows the frames of one sequence concatenated horizontally.
Args:
dataset: Numpy array of shape (total_samples, NUM_FRAMES, H, W, C).
output_dir: Directory to save the output grid image.
filename: Name for the output image file.
grid_size: The dimensions of the grid (grid_size x grid_size).
num_frames_to_show: How many frames constitute a sequence in the dataset.
Should match dataset.shape[1].
"""
total_samples = dataset.shape[0]
if total_samples == 0:
print("Dataset is empty, cannot visualize.")
return
# Ensure num_frames_to_show matches the dataset dimension
if dataset.shape[1] != num_frames_to_show:
print(
f"Warning: num_frames_to_show ({num_frames_to_show}) does not match "
f"dataset's second dimension ({dataset.shape[1]}). Using dataset's dimension."
)
num_frames_to_show = dataset.shape[1]
# Get image dimensions from the first sample's first frame
img_h, img_w = dataset.shape[2], dataset.shape[3]
num_cells = grid_size * grid_size
if num_cells > total_samples:
print(
f"Warning: Grid size ({num_cells}) is larger than total samples ({total_samples})."
f" Showing all samples."
)
num_cells = total_samples
# Adjust grid_size down if possible, otherwise plot might have empty cells
grid_size = int(np.ceil(np.sqrt(num_cells)))
fig, axes = plt.subplots(
grid_size, grid_size, figsize=(grid_size * num_frames_to_show, grid_size * 1.2)
) # Adjust figsize based on frames
# Select random indices to display
indices_to_show = np.random.choice(total_samples, size=num_cells, replace=False)
for idx, ax_idx in enumerate(indices_to_show):
row = idx // grid_size
col = idx % grid_size
ax = axes[row, col]
sequence = dataset[ax_idx] # Shape: (NUM_FRAMES, H, W, C)
# Concatenate frames horizontally
# Output shape: (H, NUM_FRAMES * W, C)
concatenated_image = np.concatenate(sequence, axis=1)
ax.imshow(concatenated_image)
ax.set_title(f"Sample {ax_idx}", fontsize=8)
ax.axis("off")
# --- Plot detected centers on the concatenated image ---
for frame_num, frame_img in enumerate(sequence):
try:
centers, _ = detect_circle_centers(frame_img, (BLUE, GREEN))
# Adjust center coordinates for horizontal concatenation
for center in centers:
# center is (row, col)
plot_col = (
center[1] + frame_num * img_w
) # Add offset based on frame index
plot_row = center[0]
ax.scatter(
plot_col, plot_row, c="red", s=10, marker="x"
) # Smaller marker
except Exception as e:
print(
f"Error detecting circles in sample {ax_idx}, frame {frame_num}: {e}"
)
# Handle remaining empty subplots if num_cells < grid_size*grid_size
total_plots = grid_size * grid_size
for idx in range(num_cells, total_plots):
row = idx // grid_size
col = idx % grid_size
if grid_size > 1: # Avoid error if grid_size is 1 and axes is not an array
fig.delaxes(axes[row, col])
else:
fig.delaxes(axes) # Special case for 1x1 grid
plt.tight_layout(pad=0.5, h_pad=1.0, w_pad=0.5) # Adjust padding
output_path = os.path.join(output_dir, filename)
os.makedirs(output_dir, exist_ok=True) # Create dir if it doesn't exist
plt.savefig(output_path)
print(f"Saved sequence visualization grid to {output_path}")
plt.close(fig)
def generate_synthetic_samples( def generate_synthetic_samples(
model, model,
image_size: int, image_size: int,

Loading…
Cancel
Save