# NOTE(review): this text was recovered from a unified diff whose line breaks
# had been collapsed, so indentation and blank lines are reconstructed.
#
# The diff touches two modules:
#   * src/ddpm/generate_circle_dataset.py  (index fdea3f3..431076c 100644)
#   * src/ddpm/utils.py                    (index 9ea0960..a8219b2 100644)
#
# Definitions that the diff adds in full are written out below as regular
# code. Hunks that only modify parts of pre-existing definitions are kept as
# commented diff text at the bottom, because the bodies they belong to are not
# visible here.
#
# Names the code below expects its target module to provide:
#   Image, ImageDraw                  -- "from PIL import Image, ImageDraw"
#   BACKGROUND, GREEN, BLUE, WHITE    -- "from values import ..."
#   plt, detect_circle_centers        -- used inside utils.py; their origin is
#                                        not shown in the diff.

import math
import os
from typing import Optional

import numpy as np

# Number of frames in every generated sequence.
NUM_FRAMES = 3

# Slack (in squared high-resolution pixels) applied to the distance comparison
# so that values sitting exactly on a bound survive floating-point rounding.
_DISTANCE_TOLERANCE_SQ = 1e-9


# ---------------------------------------------------------------------------
# src/ddpm/generate_circle_dataset.py -- new definitions
# ---------------------------------------------------------------------------


def calculate_position(pos0, vel, acc, t):
    """Return the position at time ``t`` under constant acceleration.

    Args:
        pos0: Initial ``(x, y)`` position.
        vel: Initial ``(vx, vy)`` velocity.
        acc: Constant ``(ax, ay)`` acceleration.
        t: Time at which to evaluate the position.

    Returns:
        The ``(x, y)`` tuple ``pos0 + vel * t + 0.5 * acc * t**2``.
    """
    x0, y0 = pos0
    vx, vy = vel
    ax, ay = acc
    xt = x0 + vx * t + 0.5 * ax * t**2
    yt = y0 + vy * t + 0.5 * ay * t**2
    return xt, yt


def _circle_inside(center, size, radius):
    """Return True when a circle of ``radius`` at ``center`` fits in the canvas."""
    # floor/ceil make the check conservative for fractional coordinates.
    return all(
        radius <= math.floor(coord) and math.ceil(coord) <= size - radius
        for coord in center
    )


def check_constraints(
    pos0, pos1, size, radius, distance, delta, scale, enable_constraint
):
    """Check the boundary and distance constraints for a single frame.

    Args:
        pos0: ``(x, y)`` centre of circle 0, in high-resolution pixels.
        pos1: ``(x, y)`` centre of circle 1, in high-resolution pixels.
        size: Side length of the high-resolution canvas.
        radius: Circle radius in high-resolution pixels.
        distance: Target centre-to-centre distance, in low-resolution pixels.
        delta: Allowed deviation from ``distance``, in low-resolution pixels.
        scale: Factor converting low-resolution units to high-resolution ones.
        enable_constraint: When False only the boundary check is applied.

    Returns:
        True if both circles lie fully inside the canvas and, when
        ``enable_constraint`` is set, their centre distance lies within
        ``[(distance - delta) * scale, (distance + delta) * scale]`` up to a
        small floating-point tolerance. An empty interval (for example a
        negative ``delta``) never matches, so it returns False.
    """
    if not (_circle_inside(pos0, size, radius) and _circle_inside(pos1, size, radius)):
        return False
    if not enable_constraint:
        return True

    x0, y0 = pos0
    x1, y1 = pos1
    dist_sq = (x0 - x1) ** 2 + (y0 - y1) ** 2

    # A distance cannot be negative, so clamp the lower bound at zero.
    min_dist = max(0, (distance - delta) * scale)
    max_dist = (distance + delta) * scale
    if max_dist < min_dist:
        # Squaring a negative upper bound would silently turn an empty
        # interval into a permissive one; reject it explicitly instead.
        return False

    # Compare squared values to avoid a sqrt; bool() strips numpy scalar types.
    return bool(
        min_dist**2 - _DISTANCE_TOLERANCE_SQ
        <= dist_sq
        <= max_dist**2 + _DISTANCE_TOLERANCE_SQ
    )


def _random_polar_vector(rng, max_magnitude):
    """Draw a 2-D vector with uniform direction and magnitude in [0, max_magnitude)."""
    angle = rng.uniform(0, 2 * np.pi)
    magnitude = rng.uniform(0, max_magnitude)
    return magnitude * np.cos(angle), magnitude * np.sin(angle)


def _render_frame(
    pos0,
    pos1,
    high_res_size,
    high_res_radius,
    image_size,
    scale,
    enable_rectangle,
    rectangle_thickness,
):
    """Draw one frame at high resolution and downsample it to ``image_size``."""
    x0, y0 = pos0
    x1, y1 = pos1

    im = Image.new("RGB", (high_res_size, high_res_size), BACKGROUND)
    draw = ImageDraw.Draw(im)

    if enable_rectangle:
        dx = x1 - x0
        dy = y1 - y0
        d_sq = dx**2 + dy**2
        if d_sq > 1e-9:
            d = math.sqrt(d_sq)
            ux, uy = dx / d, dy / d
        else:
            # Coincident centres have no direction; fall back to the x axis.
            ux, uy = 1.0, 0.0

        # The bar extends one radius beyond each centre along the joining axis.
        start_x = x0 - high_res_radius * ux
        start_y = y0 - high_res_radius * uy
        end_x = x1 + high_res_radius * ux
        end_y = y1 + high_res_radius * uy

        # Never thinner than the circles' diameter.
        half_thickness = max(rectangle_thickness * scale, 2 * high_res_radius) / 2.0
        off_x = half_thickness * -uy
        off_y = half_thickness * ux

        draw.polygon(
            [
                (start_x + off_x, start_y + off_y),
                (start_x - off_x, start_y - off_y),
                (end_x - off_x, end_y - off_y),
                (end_x + off_x, end_y + off_y),
            ],
            fill=WHITE,
        )

    # Ellipse bounding boxes are (left, top, right, bottom).
    for (cx, cy), colour in (((x0, y0), GREEN), ((x1, y1), BLUE)):
        draw.ellipse(
            (
                cx - high_res_radius,
                cy - high_res_radius,
                cx + high_res_radius,
                cy + high_res_radius,
            ),
            fill=colour,
        )

    im_resized = im.resize((image_size, image_size), resample=Image.LANCZOS)
    return np.array(im_resized).astype(np.uint8)


def create_sample_sequence_antialiased(
    id: int,
    image_size: int,
    distance: int,
    radius: int,
    delta: int,
    scale: int = 4,
    enable_constraint: bool = True,
    enable_rectangle: bool = False,
    rectangle_thickness: int = 0,
    max_initial_speed: float = 5.0,
    max_acceleration: float = 1.0,
    seed: Optional[int] = None,
):
    """Generate ``NUM_FRAMES`` frames of two circles moving with constant acceleration.

    Initial positions, velocities and accelerations are drawn at random and
    the whole trajectory is rejected unless every frame satisfies
    ``check_constraints``.

    Args:
        id: Identifier returned unchanged alongside the frames.
        image_size: Side length of the output frames, in pixels.
        distance: Target centre-to-centre distance, in output pixels.
        radius: Circle radius, in output pixels.
        delta: Allowed deviation from ``distance``, in output pixels.
        scale: Supersampling factor used while drawing.
        enable_constraint: Whether the distance constraint is enforced.
        enable_rectangle: Whether to draw a bar connecting the two circles.
        rectangle_thickness: Bar thickness in output pixels (at least the
            circle diameter is always used).
        max_initial_speed: Upper bound on initial speed, output pixels/frame.
        max_acceleration: Upper bound on acceleration, output pixels/frame^2.
        seed: Optional seed for a reproducible sample. With ``None`` fresh OS
            entropy is used on every call.

    Returns:
        ``(id, frames)`` where ``frames`` stacks the ``NUM_FRAMES`` resized
        uint8 frames along a new leading axis.

    Raises:
        ValueError: If the radius exceeds half of the image size.
        RuntimeError: If no valid trajectory is found within the attempt limit.
    """
    high_res_size = image_size * scale
    high_res_radius = radius * scale
    high_res_max_speed = max_initial_speed * scale
    high_res_max_accel = max_acceleration * scale

    if high_res_radius > high_res_size / 2:
        raise ValueError("Radius is too large for the image size.")

    # A per-call generator instead of the global numpy state: this function is
    # dispatched through a process pool, and forked workers inherit identical
    # global RNG state, which would make them emit duplicate sequences.
    rng = np.random.default_rng(seed)

    # Sampling bounds do not change between attempts, so compute them once.
    low_bound = high_res_radius
    high_bound = high_res_size - high_res_radius
    buffer = 1  # keep initial centres one pixel away from the limit
    safe_low = low_bound + buffer
    safe_high = high_bound - buffer
    if safe_low >= safe_high:
        # Radius leaves no room for the buffer; use the raw bounds instead.
        safe_low = low_bound
        safe_high = high_bound + 1e-6

    max_attempts = 10_000  # safety limit for the rejection loop

    for _ in range(max_attempts):
        start0 = tuple(rng.uniform(low=safe_low, high=safe_high, size=2))
        start1 = tuple(rng.uniform(low=safe_low, high=safe_high, size=2))
        vel0 = _random_polar_vector(rng, high_res_max_speed)
        vel1 = _random_polar_vector(rng, high_res_max_speed)
        acc0 = _random_polar_vector(rng, high_res_max_accel)
        acc1 = _random_polar_vector(rng, high_res_max_accel)

        trajectory = []
        for t in range(NUM_FRAMES):
            pos0_t = calculate_position(start0, vel0, acc0, t)
            pos1_t = calculate_position(start1, vel1, acc1, t)
            if not check_constraints(
                pos0_t,
                pos1_t,
                high_res_size,
                high_res_radius,
                distance,
                delta,
                scale,
                enable_constraint,
            ):
                break  # one bad frame invalidates the whole attempt
            trajectory.append((pos0_t, pos1_t))
        else:
            # Reached only when no frame broke out of the loop.
            frames = [
                _render_frame(
                    pos0_t,
                    pos1_t,
                    high_res_size,
                    high_res_radius,
                    image_size,
                    scale,
                    enable_rectangle,
                    rectangle_thickness,
                )
                for pos0_t, pos1_t in trajectory
            ]
            return id, np.stack(frames, axis=0)

    raise RuntimeError(
        f"Failed to generate a valid sequence after {max_attempts} attempts. "
        "Check constraints, movement parameters, or image size/radius."
    )


# ---------------------------------------------------------------------------
# src/ddpm/utils.py -- new definition (inserted after visualize_samples)
# ---------------------------------------------------------------------------


def visualize_sequences(
    dataset: np.ndarray,
    output_dir: str,
    filename: str = "sequence_grid.png",
    grid_size: int = 4,
    num_frames_to_show: int = 3,
):
    """Save a grid of randomly chosen sequences from a video dataset.

    Each grid cell shows the frames of one sequence concatenated horizontally,
    with the centres reported by ``detect_circle_centers`` marked in red.

    Args:
        dataset: Array indexed as ``(sample, frame, row, column, channel)``.
        output_dir: Directory for the output image; created if missing.
        filename: Name of the output image file.
        grid_size: The grid is ``grid_size x grid_size``; it is shrunk when
            the dataset holds fewer samples than cells.
        num_frames_to_show: Expected frames per sequence. If it differs from
            ``dataset.shape[1]`` a warning is printed and the dataset's value
            is used.
    """
    total_samples = dataset.shape[0]
    if total_samples == 0:
        print("Dataset is empty, cannot visualize.")
        return

    if dataset.shape[1] != num_frames_to_show:
        print(
            f"Warning: num_frames_to_show ({num_frames_to_show}) does not match "
            f"dataset's second dimension ({dataset.shape[1]}). Using dataset's dimension."
        )
        num_frames_to_show = dataset.shape[1]

    img_w = dataset.shape[3]

    num_cells = grid_size * grid_size
    if num_cells > total_samples:
        print(
            f"Warning: Grid size ({num_cells}) is larger than total samples ({total_samples})."
            f" Showing all samples."
        )
        num_cells = total_samples
        grid_size = int(np.ceil(np.sqrt(num_cells)))

    # squeeze=False keeps ``axes`` two-dimensional even for a 1x1 grid, where
    # the default would hand back a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(
        grid_size,
        grid_size,
        figsize=(grid_size * num_frames_to_show, grid_size * 1.2),
        squeeze=False,
    )

    indices_to_show = np.random.choice(total_samples, size=num_cells, replace=False)

    for cell, sample_idx in enumerate(indices_to_show):
        row, col = divmod(cell, grid_size)
        ax = axes[row, col]

        sequence = dataset[sample_idx]

        # Lay the frames side by side along the width axis.
        ax.imshow(np.concatenate(sequence, axis=1))
        ax.set_title(f"Sample {sample_idx}", fontsize=8)
        ax.axis("off")

        for frame_num, frame_img in enumerate(sequence):
            # Marker drawing is best-effort: a detection failure is reported
            # but must not abort the whole visualization.
            try:
                centers, _ = detect_circle_centers(frame_img, (BLUE, GREEN))
                for center in centers:
                    # center is (row, col); shift the column by the frame offset.
                    ax.scatter(
                        center[1] + frame_num * img_w,
                        center[0],
                        c="red",
                        s=10,
                        marker="x",
                    )
            except Exception as e:
                print(
                    f"Error detecting circles in sample {sample_idx}, frame {frame_num}: {e}"
                )

    # Drop the cells that received no sample.
    for unused_ax in axes.ravel()[num_cells:]:
        fig.delaxes(unused_ax)

    plt.tight_layout(pad=0.5, h_pad=1.0, w_pad=0.5)
    output_path = os.path.join(output_dir, filename)
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(output_path)
    print(f"Saved sequence visualization grid to {output_path}")
    plt.close(fig)


# ---------------------------------------------------------------------------
# Remaining hunks of the diff, preserved as text. Hunk headers carry the line
# numbers of the original diff.
# ---------------------------------------------------------------------------
#
# diff --git a/src/ddpm/generate_circle_dataset.py b/src/ddpm/generate_circle_dataset.py
# --- a/src/ddpm/generate_circle_dataset.py
# +++ b/src/ddpm/generate_circle_dataset.py
# @@ -1,5 +1,7 @@
# +from typing import Callable
#  import argparse
#  import numpy as np
# +import math
#  from PIL import Image, ImageDraw
#  from tqdm import tqdm
# @@ -10,13 +12,272 @@
#  from itertools import repeat
#
#  import matplotlib
#
# -from utils import visualize_samples
# +from utils import visualize_samples, visualize_sequences
#  from values import GREEN, BLUE, WHITE, BACKGROUND, EXPERIMENTS
#
#  matplotlib.use("Agg")
#
# +  (NUM_FRAMES, calculate_position, check_constraints and
# +   create_sample_sequence_antialiased -- written out above)
#
#  def create_sample_antialiased(
#      id: int,
#      image_size: int,
# @@ -100,6 +361,7 @@ def generate_circle_dataset(
#      radius,
#      distance,
#      delta,
# +    generation_fn: Callable,
#      enable_constraint: bool = True,
#      enable_rectangle: bool = False,
#      rectangle_thickness: int = 0,
# @@ -117,7 +379,7 @@ def generate_circle_dataset(
#      with ProcessPoolExecutor(max_workers=32) as executor:
#          for i, sample in executor.map(
# -            create_sample_antialiased,
# +            generation_fn,
#              range(num_samples),
#              repeat(image_size),
#              repeat(distance),
# @@ -138,10 +400,19 @@ def main(args, experiment):
#      os.makedirs(args.output_dir, exist_ok=True)
#
# -    dataset = np.empty((args.total_samples, image_size, image_size, 3), dtype=np.uint8)
# +    dataset = (
# +        np.empty((args.total_samples, image_size, image_size, 3), dtype=np.uint8)
# +        if not args.video
# +        else np.empty(
# +            (args.total_samples, NUM_FRAMES, image_size, image_size, 3), dtype=np.uint8
# +        )
# +    )
#      iterator = generate_circle_dataset(
#          image_size=image_size,
#          num_samples=args.total_samples,
# +        generation_fn=create_sample_antialiased
# +        if not args.video
# +        else create_sample_sequence_antialiased,
#          distance=experiment["distance"],
#          delta=experiment["delta"],
#          radius=experiment["radius"],
# @@ -152,12 +423,25 @@ def main(args, experiment):
#      for i, sample in tqdm(iterator, total=args.total_samples):
#          dataset[i] = sample
#
# -    visualize_samples(
# -        dataset, args.output_dir, f"sample_grid_{args.experiment_name}.png"
# -    )
# +    if not args.video:
# +        visualize_samples(
# +            dataset, args.output_dir, f"sample_grid_{args.experiment_name}.png"
# +        )
# +    else:
# +        visualize_sequences(
# +            dataset,
# +            args.output_dir,
# +            f"sequence_grid_{args.experiment_name}.png",
# +            num_frames_to_show=NUM_FRAMES,
# +        )
#
#      # Save the dataset
# -    np.save(os.path.join(args.output_dir, f"data-{args.experiment_name}.npy"), dataset)
# +    dataset_filename = (
# +        f"data-{args.experiment_name}.npy"
# +        if not args.video
# +        else f"data-video-{args.experiment_name}.npy"
# +    )
# +    np.save(os.path.join(args.output_dir, dataset_filename), dataset)
#      # np.savez_compressed(os.path.join(output_dir, "data.npy.npz"), dataset)
# @@ -168,6 +452,12 @@ if __name__ == "__main__":
#      )
#      parser.add_argument("-o", "--output_dir", default="data/circle_dataset")
#      parser.add_argument("-t", "--total_samples", default=1_000_000, type=int)
# +    parser.add_argument(
# +        "-v",
# +        "--video",
# +        action="store_true",
# +        help="If set, generate a dataset of videos of moving circles",
# +    )
#      args = parser.parse_args()
#
#      main(args, EXPERIMENTS[args.experiment_name])
#
# diff --git a/src/ddpm/utils.py b/src/ddpm/utils.py
# --- a/src/ddpm/utils.py
# +++ b/src/ddpm/utils.py
# @@ -297,6 +297,111 @@ def visualize_samples(
#      plt.close()
#
# +  (visualize_sequences -- written out above)
#
#  def generate_synthetic_samples(
#      model,
#      image_size: int,