Browse Source

feat: add video dataset generation

master
CALVO GONZALEZ Ramon 8 months ago
parent
commit
d9c4afaca4
  1. 298
      src/ddpm/generate_circle_dataset.py
  2. 105
      src/ddpm/utils.py

298
src/ddpm/generate_circle_dataset.py

@ -1,5 +1,7 @@
from typing import Callable
import argparse
import numpy as np
import math
from PIL import Image, ImageDraw
from tqdm import tqdm
@ -10,13 +12,272 @@ from itertools import repeat
import matplotlib
from utils import visualize_samples
from utils import visualize_samples, visualize_sequences
from values import GREEN, BLUE, WHITE, BACKGROUND, EXPERIMENTS
matplotlib.use("Agg")
NUM_FRAMES = 3
def calculate_position(pos0, vel, acc, t):
"""Calculates position at time t using initial pos, vel, and const acc."""
x0, y0 = pos0
vx, vy = vel
ax, ay = acc
xt = x0 + vx * t + 0.5 * ax * t**2
yt = y0 + vy * t + 0.5 * ay * t**2
return xt, yt
# --- Modified check_constraints function ---
def check_constraints(
pos0, pos1, size, radius, distance, delta, scale, enable_constraint
):
"""
Checks distance and boundary constraints for a single frame.
Args:
pos0: (x, y) tuple for circle 0 position (high-res).
pos1: (x, y) tuple for circle 1 position (high-res).
size: High-resolution image size.
radius: High-resolution radius.
distance: Target distance between centers (low-res).
delta: Allowed tolerance for distance (low-res).
scale: The scaling factor used for high-resolution. <--- Added argument
enable_constraint: Boolean indicating if distance constraint is active.
Returns:
True if constraints are met, False otherwise.
"""
x0, y0 = pos0
x1, y1 = pos1
# 1. Boundary Check (circles must be fully inside)
# Use floor/ceil to be safe with floating point positions near boundary
if not (
radius <= math.floor(x0)
and math.ceil(x0) <= size - radius
and radius <= math.floor(y0)
and math.ceil(y0) <= size - radius
and radius <= math.floor(x1)
and math.ceil(x1) <= size - radius
and radius <= math.floor(y1)
and math.ceil(y1) <= size - radius
):
# print(f"Boundary fail: ({x0:.1f},{y0:.1f}), ({x1:.1f},{y1:.1f}) in size {size} w/ radius {radius}") # Debugging
return False # Out of bounds
# 2. Distance Check (if enabled)
if enable_constraint:
dist_sq = (x0 - x1) ** 2 + (y0 - y1) ** 2
# Scale the target distance and delta to high-resolution units for comparison
min_dist_highres = (distance - delta) * scale
max_dist_highres = (distance + delta) * scale
# Handle cases where min_dist_highres might be negative if delta > distance
min_dist_sq = (
max(0, min_dist_highres) ** 2
) # Distance squared cannot be negative
max_dist_sq = max_dist_highres**2
# Check if the squared distance is within the allowed squared range
if not (min_dist_sq <= dist_sq <= max_dist_sq):
# Add a small tolerance for floating point comparisons, especially if delta is small
tolerance_sq = 1e-9
if delta > 0 and (
dist_sq < min_dist_sq - tolerance_sq
or dist_sq > max_dist_sq + tolerance_sq
):
# print(f"Distance fail: sqrt({dist_sq:.2f}) not in [{min_dist_highres:.2f}, {max_dist_highres:.2f}]") # Debugging
return False # Distance constraint violated
elif delta == 0 and abs(dist_sq - (distance * scale) ** 2) > tolerance_sq:
# print(f"Exact Distance fail: sqrt({dist_sq:.2f}) != {distance*scale:.2f}") # Debugging
return False # Exact distance constraint violated
return True # All constraints met for this frame
def create_sample_sequence_antialiased(
id: int,
image_size: int,
distance: int,
radius: int,
delta: int,
scale: int = 4,
enable_constraint: bool = True,
enable_rectangle: bool = False,
rectangle_thickness: int = 0,
max_initial_speed: float = 5.0,
max_acceleration: float = 1.0,
):
"""
Generates a sequence of NUM_FRAMES images with two moving circles.
Applies rejection sampling to ensure constraints are met across all frames.
Movement includes linear velocity and constant acceleration (curvature).
"""
high_res_size = image_size * scale
high_res_radius = radius * scale
high_res_max_speed = max_initial_speed * scale
high_res_max_accel = max_acceleration * scale
# Ensure radius isn't too large for the image size
if high_res_radius > high_res_size / 2:
raise ValueError("Radius is too large for the image size.")
attempts = 0
max_attempts = 10_000 # Safety break to prevent infinite loops
while attempts < max_attempts:
attempts += 1
all_frames_valid = True
positions0 = []
positions1 = []
# --- Generate Initial State (Frame 0) and Movement ---
low_bound = high_res_radius
high_bound = high_res_size - high_res_radius
# Ensure initial positions are valid before starting loop
# Add a small buffer if possible to avoid immediate boundary issues
buffer = 1 # Add a small pixel buffer from the edge
safe_low = low_bound + buffer
safe_high = high_bound - buffer
if safe_low >= safe_high: # Handle cases where radius is very large
safe_low = low_bound
safe_high = high_bound + 1e-6 # Add epsilon for randint upper bound
x0_0, y0_0 = np.random.uniform(low=safe_low, high=safe_high, size=2)
x1_0, y1_0 = np.random.uniform(low=safe_low, high=safe_high, size=2)
initial_pos0 = (x0_0, y0_0)
initial_pos1 = (x1_0, y1_0)
angle0 = np.random.uniform(0, 2 * np.pi)
speed0 = np.random.uniform(0, high_res_max_speed)
vel0 = (speed0 * np.cos(angle0), speed0 * np.sin(angle0))
angle1 = np.random.uniform(0, 2 * np.pi)
speed1 = np.random.uniform(0, high_res_max_speed)
vel1 = (speed1 * np.cos(angle1), speed1 * np.sin(angle1))
angle_a0 = np.random.uniform(0, 2 * np.pi)
mag_a0 = np.random.uniform(0, high_res_max_accel)
acc0 = (mag_a0 * np.cos(angle_a0), mag_a0 * np.sin(angle_a0))
angle_a1 = np.random.uniform(0, 2 * np.pi)
mag_a1 = np.random.uniform(0, high_res_max_accel)
acc1 = (mag_a1 * np.cos(angle_a1), mag_a1 * np.sin(angle_a1))
# --- Calculate positions and check constraints for all frames ---
for t in range(NUM_FRAMES):
pos0_t = calculate_position(initial_pos0, vel0, acc0, t)
pos1_t = calculate_position(initial_pos1, vel1, acc1, t)
# --- Updated call to check_constraints ---
if not check_constraints(
pos0_t,
pos1_t,
high_res_size,
high_res_radius,
distance,
delta,
scale,
enable_constraint, # Pass scale here
):
all_frames_valid = False
break # No need to check further frames for this attempt
positions0.append(pos0_t)
positions1.append(pos1_t)
# --- If all frames are valid, generate images ---
if all_frames_valid:
# print(f"Valid sequence found after {attempts} attempts.") # Debugging
frames = []
for t in range(NUM_FRAMES):
x0, y0 = positions0[t]
x1, y1 = positions1[t]
im = Image.new("RGB", (high_res_size, high_res_size), BACKGROUND)
draw = ImageDraw.Draw(im)
if enable_rectangle:
dx = x1 - x0
dy = y1 - y0
d_sq = dx**2 + dy**2
if d_sq > 1e-9:
d = math.sqrt(d_sq)
ux = dx / d
uy = dy / d
else:
ux, uy = 1.0, 0.0
start_x = x0 - high_res_radius * ux
start_y = y0 - high_res_radius * uy
end_x = x1 + high_res_radius * ux
end_y = y1 + high_res_radius * uy
thickness = max(rectangle_thickness * scale, 2 * high_res_radius)
half_thickness = thickness / 2.0
perp_x = -uy
perp_y = ux
p1 = (
start_x + half_thickness * perp_x,
start_y + half_thickness * perp_y,
)
p2 = (
start_x - half_thickness * perp_x,
start_y - half_thickness * perp_y,
)
p3 = (
end_x - half_thickness * perp_x,
end_y - half_thickness * perp_y,
)
p4 = (
end_x + half_thickness * perp_x,
end_y + half_thickness * perp_y,
)
draw.polygon([p1, p2, p3, p4], fill=WHITE)
# Draw circles using ellipse for floating point centers
# Bounding box for ellipse: (left, top, right, bottom)
draw.ellipse(
(
x0 - high_res_radius,
y0 - high_res_radius,
x0 + high_res_radius,
y0 + high_res_radius,
),
fill=GREEN,
)
draw.ellipse(
(
x1 - high_res_radius,
y1 - high_res_radius,
x1 + high_res_radius,
y1 + high_res_radius,
),
fill=BLUE,
)
im_resized = im.resize((image_size, image_size), resample=Image.LANCZOS)
frames.append(np.array(im_resized).astype(np.uint8))
return id, np.stack(frames, axis=0) # Return the sequence of valid frames
# If max_attempts reached without finding a valid sequence
raise RuntimeError(
f"Failed to generate a valid sequence after {max_attempts} attempts. "
"Check constraints, movement parameters, or image size/radius."
)
def create_sample_antialiased(
id: int,
image_size: int,
@ -100,6 +361,7 @@ def generate_circle_dataset(
radius,
distance,
delta,
generation_fn: Callable,
enable_constraint: bool = True,
enable_rectangle: bool = False,
rectangle_thickness: int = 0,
@ -117,7 +379,7 @@ def generate_circle_dataset(
with ProcessPoolExecutor(max_workers=32) as executor:
for i, sample in executor.map(
create_sample_antialiased,
generation_fn,
range(num_samples),
repeat(image_size),
repeat(distance),
@ -138,10 +400,19 @@ def main(args, experiment):
os.makedirs(args.output_dir, exist_ok=True)
dataset = np.empty((args.total_samples, image_size, image_size, 3), dtype=np.uint8)
dataset = (
np.empty((args.total_samples, image_size, image_size, 3), dtype=np.uint8)
if not args.video
else np.empty(
(args.total_samples, NUM_FRAMES, image_size, image_size, 3), dtype=np.uint8
)
)
iterator = generate_circle_dataset(
image_size=image_size,
num_samples=args.total_samples,
generation_fn=create_sample_antialiased
if not args.video
else create_sample_sequence_antialiased,
distance=experiment["distance"],
delta=experiment["delta"],
radius=experiment["radius"],
@ -152,12 +423,25 @@ def main(args, experiment):
for i, sample in tqdm(iterator, total=args.total_samples):
dataset[i] = sample
if not args.video:
visualize_samples(
dataset, args.output_dir, f"sample_grid_{args.experiment_name}.png"
)
else:
visualize_sequences(
dataset,
args.output_dir,
f"sequence_grid_{args.experiment_name}.png",
num_frames_to_show=NUM_FRAMES,
)
# Save the dataset
np.save(os.path.join(args.output_dir, f"data-{args.experiment_name}.npy"), dataset)
dataset_filename = (
f"data-{args.experiment_name}.npy"
if not args.video
else f"data-video-{args.experiment_name}.npy"
)
np.save(os.path.join(args.output_dir, dataset_filename), dataset)
# np.savez_compressed(os.path.join(output_dir, "data.npy.npz"), dataset)
@ -168,6 +452,12 @@ if __name__ == "__main__":
)
parser.add_argument("-o", "--output_dir", default="data/circle_dataset")
parser.add_argument("-t", "--total_samples", default=1_000_000, type=int)
parser.add_argument(
"-v",
"--video",
action="store_true",
help="If set, generate a dataset of videos of moving circles",
)
args = parser.parse_args()
main(args, EXPERIMENTS[args.experiment_name])

105
src/ddpm/utils.py

@ -297,6 +297,111 @@ def visualize_samples(
plt.close()
def visualize_sequences(
dataset: np.array,
output_dir: str,
filename: str = "sequence_grid.png",
grid_size: int = 4, # Adjust grid size as needed
num_frames_to_show: int = 3, # Number of frames per sequence (should match dataset)
):
"""
Visualizes random samples from a video dataset.
Each cell in the grid shows the frames of one sequence concatenated horizontally.
Args:
dataset: Numpy array of shape (total_samples, NUM_FRAMES, H, W, C).
output_dir: Directory to save the output grid image.
filename: Name for the output image file.
grid_size: The dimensions of the grid (grid_size x grid_size).
num_frames_to_show: How many frames constitute a sequence in the dataset.
Should match dataset.shape[1].
"""
total_samples = dataset.shape[0]
if total_samples == 0:
print("Dataset is empty, cannot visualize.")
return
# Ensure num_frames_to_show matches the dataset dimension
if dataset.shape[1] != num_frames_to_show:
print(
f"Warning: num_frames_to_show ({num_frames_to_show}) does not match "
f"dataset's second dimension ({dataset.shape[1]}). Using dataset's dimension."
)
num_frames_to_show = dataset.shape[1]
# Get image dimensions from the first sample's first frame
img_h, img_w = dataset.shape[2], dataset.shape[3]
num_cells = grid_size * grid_size
if num_cells > total_samples:
print(
f"Warning: Grid size ({num_cells}) is larger than total samples ({total_samples})."
f" Showing all samples."
)
num_cells = total_samples
# Adjust grid_size down if possible, otherwise plot might have empty cells
grid_size = int(np.ceil(np.sqrt(num_cells)))
fig, axes = plt.subplots(
grid_size, grid_size, figsize=(grid_size * num_frames_to_show, grid_size * 1.2)
) # Adjust figsize based on frames
# Select random indices to display
indices_to_show = np.random.choice(total_samples, size=num_cells, replace=False)
for idx, ax_idx in enumerate(indices_to_show):
row = idx // grid_size
col = idx % grid_size
ax = axes[row, col]
sequence = dataset[ax_idx] # Shape: (NUM_FRAMES, H, W, C)
# Concatenate frames horizontally
# Output shape: (H, NUM_FRAMES * W, C)
concatenated_image = np.concatenate(sequence, axis=1)
ax.imshow(concatenated_image)
ax.set_title(f"Sample {ax_idx}", fontsize=8)
ax.axis("off")
# --- Plot detected centers on the concatenated image ---
for frame_num, frame_img in enumerate(sequence):
try:
centers, _ = detect_circle_centers(frame_img, (BLUE, GREEN))
# Adjust center coordinates for horizontal concatenation
for center in centers:
# center is (row, col)
plot_col = (
center[1] + frame_num * img_w
) # Add offset based on frame index
plot_row = center[0]
ax.scatter(
plot_col, plot_row, c="red", s=10, marker="x"
) # Smaller marker
except Exception as e:
print(
f"Error detecting circles in sample {ax_idx}, frame {frame_num}: {e}"
)
# Handle remaining empty subplots if num_cells < grid_size*grid_size
total_plots = grid_size * grid_size
for idx in range(num_cells, total_plots):
row = idx // grid_size
col = idx % grid_size
if grid_size > 1: # Avoid error if grid_size is 1 and axes is not an array
fig.delaxes(axes[row, col])
else:
fig.delaxes(axes) # Special case for 1x1 grid
plt.tight_layout(pad=0.5, h_pad=1.0, w_pad=0.5) # Adjust padding
output_path = os.path.join(output_dir, filename)
os.makedirs(output_dir, exist_ok=True) # Create dir if it doesn't exist
plt.savefig(output_path)
print(f"Saved sequence visualization grid to {output_path}")
plt.close(fig)
def generate_synthetic_samples(
model,
image_size: int,

Loading…
Cancel
Save