You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

188 lines
6.1 KiB

import numpy as np
from scipy.ndimage import label, center_of_mass
from PIL import Image, ImageDraw
from tqdm import tqdm
import os
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat
import matplotlib.pyplot as plt
# 8-bit RGB triples, spelled as hex color codes and unpacked into int tuples.
RED = tuple(bytes.fromhex("CC241D"))  # not referenced by the code in this file
GREEN = tuple(bytes.fromhex("98971A"))  # fill color of the first circle
BLUE = tuple(bytes.fromhex("458588"))  # fill color of the second circle
BACKGROUND = tuple(bytes.fromhex("504945"))  # canvas color behind the circles
def create_sample_antialiased(
    id: int, image_size: int, distance: int, radius: int, delta: int, scale=4
):
    """Render one image of two filled circles, anti-aliased by supersampling.

    The scene is drawn on a canvas ``scale`` times larger than requested and
    then shrunk with a Lanczos filter, which smooths the circle edges.

    Args:
        id: Sample index; returned unchanged so results can be matched to
            their requests.
        image_size: Side length, in pixels, of the square output image.
        distance: Target distance between the circle centers, in
            output-resolution pixels.
        radius: Circle radius in output-resolution pixels.
        delta: Allowed deviation from ``distance`` (must be >= 0).
        scale: Supersampling factor used while drawing.

    Returns:
        ``(id, image)`` where ``image`` is the downsampled RGB picture as a
        NumPy array of shape ``(image_size, image_size, 3)``.

    Raises:
        ValueError: If the image is too small for the radius, or if no pair
            of centers can ever satisfy the requested distance window
            (previously this made the sampling loop spin forever).
    """
    high_res_size = image_size * scale
    high_res_radius = radius * scale
    lower = (distance - delta) * scale
    upper = (distance + delta) * scale

    # Centers are sampled from [high_res_radius, high_res_size - high_res_radius),
    # so ``span`` is the largest offset two centers can have along one axis.
    span = high_res_size - 2 * high_res_radius - 1
    if span < 0:
        raise ValueError(
            f"image_size={image_size} leaves no room for a circle center "
            f"that is radius={radius} pixels away from every edge"
        )
    farthest = float(np.sqrt(2 * span * span))
    if delta < 0 or upper < 0 or lower > farthest:
        raise ValueError(
            f"no center pair can be {distance} +/- {delta} pixels apart in a "
            f"{image_size}px image with radius={radius}"
        )

    # Rejection-sample the two centers until their distance is in the window.
    dist = float("inf")
    while not lower <= dist <= upper:
        x0, y0 = np.random.randint(
            low=high_res_radius, high=high_res_size - high_res_radius, size=2
        )
        x1, y1 = np.random.randint(
            low=high_res_radius, high=high_res_size - high_res_radius, size=2
        )
        dist = np.sqrt((x0 - x1) ** 2 + (y0 - y1) ** 2)

    # Draw on the high-resolution canvas. ``ellipse`` with a square bounding
    # box is what ``ImageDraw.circle`` does internally, and it also works on
    # Pillow releases that predate ``circle``.
    im = Image.new("RGB", (high_res_size, high_res_size), BACKGROUND)
    draw = ImageDraw.Draw(im)
    for (cx, cy), color in (((x0, y0), GREEN), ((x1, y1), BLUE)):
        cx, cy = int(cx), int(cy)
        draw.ellipse(
            (
                cx - high_res_radius,
                cy - high_res_radius,
                cx + high_res_radius,
                cy + high_res_radius,
            ),
            fill=color,
        )

    # Downsample back to the target resolution; the filter does the anti-aliasing.
    im = im.resize((image_size, image_size), resample=Image.LANCZOS)
    return id, np.array(im)
def create_sample(id: int, image_size: int, distance: int, radius: int, delta: int):
    """Render one image of two hard-edged filled circles directly with NumPy.

    Args:
        id: Sample index; returned unchanged so results can be matched to
            their requests.
        image_size: Side length, in pixels, of the square image.
        distance: Target distance between the two circle centers.
        radius: Circle radius in pixels.
        delta: Allowed deviation from ``distance`` (must be >= 0).

    Returns:
        ``(id, image)`` where ``image`` is a ``uint8`` array of shape
        ``(image_size, image_size, 3)``.

    Raises:
        ValueError: If the image is too small for the radius, or if no pair
            of centers can ever satisfy the requested distance window
            (previously this made the sampling loop spin forever).
    """
    lower = distance - delta
    upper = distance + delta

    # Centers are sampled from [radius, image_size - radius), so ``span`` is
    # the largest offset two centers can have along one axis.
    span = image_size - 2 * radius - 1
    if span < 0:
        raise ValueError(
            f"image_size={image_size} leaves no room for a circle center "
            f"that is radius={radius} pixels away from every edge"
        )
    farthest = float(np.sqrt(2 * span * span))
    if delta < 0 or upper < 0 or lower > farthest:
        raise ValueError(
            f"no center pair can be {distance} +/- {delta} pixels apart in a "
            f"{image_size}px image with radius={radius}"
        )

    # Rejection-sample the two centers until their distance is in the window.
    dist = float("inf")
    while not lower <= dist <= upper:
        x0, y0 = np.random.randint(
            low=radius, high=image_size - radius, size=2, dtype=np.int32
        )
        x1, y1 = np.random.randint(
            low=radius, high=image_size - radius, size=2, dtype=np.int32
        )
        dist = np.sqrt((x0 - x1) ** 2 + (y0 - y1) ** 2)

    # Start from a canvas filled with the background color.
    img = np.full(
        shape=(image_size, image_size, 3), fill_value=BACKGROUND, dtype=np.uint8
    )

    # Open grids broadcast to the full (image_size, image_size) plane; the
    # first coordinate of each center indexes axis 0, the second axis 1.
    xx, yy = np.ogrid[:image_size, :image_size]
    mask0 = (xx - x0) ** 2 + (yy - y0) ** 2 <= radius**2
    mask1 = (xx - x1) ** 2 + (yy - y1) ** 2 <= radius**2

    # The second circle is painted last, so it wins wherever the two overlap.
    img[mask0] = GREEN
    img[mask1] = BLUE
    return id, img
def detect_circle_centers(image, background=BACKGROUND, threshold=30):
    """Find blobs that stand out from the background and return their centroids.

    A pixel counts as foreground when at least one of its channels differs
    from ``background`` by more than ``threshold``. Foreground pixels are
    grouped into connected components and the center of mass of each
    component is reported.

    Args:
        image (np.ndarray): Image array whose last axis holds the color
            channels, e.g. shape (H, W, 3).
        background: Color to treat as background; must broadcast against
            ``image``.
        threshold (int): Per-channel absolute difference a pixel has to
            exceed to be considered foreground.

    Returns:
        list of tuples: One center of mass per connected component, in label
        order; for an (H, W, 3) image each entry is (row, col).
    """
    # Work in int16 so the difference of 8-bit values can go negative
    # instead of wrapping around.
    reference = np.array(background).astype(np.int16)
    deviation = np.abs(image.astype(np.int16) - reference)

    # Collapse the channel axis: one differing channel is enough.
    foreground = np.any(deviation > threshold, axis=-1)

    # Labels start at 1; 0 is reserved for the background of the mask.
    components, count = label(foreground)
    return [
        center_of_mass(foreground, components, component_id)
        for component_id in range(1, count + 1)
    ]
def _reseed_worker():
    """Give a freshly started worker process its own NumPy random state."""
    np.random.seed()


def generate_circle_dataset(
    num_samples, image_size, radius, distance, delta, max_workers=32, chunksize=100
):
    """Lazily generate images of two circles (green and blue) in parallel.

    This is a generator: nothing is written to disk, and the worker pool is
    only started once iteration begins.

    Args:
        num_samples (int): Number of images to generate.
        image_size (int): Size of the square image (height and width).
        radius (int): Radius of the circles.
        distance (int): Base distance between the centers of the two circles.
        delta (int): Maximum variation in the distance between the circles.
        max_workers (int): Number of worker processes in the pool.
        chunksize (int): Number of samples handed to a worker per task.

    Yields:
        tuple: ``(index, image)`` pairs in increasing index order, exactly as
        returned by ``create_sample_antialiased``.
    """
    # Where worker processes are forked, each one starts with an identical
    # copy of the parent's global NumPy random state and would therefore draw
    # the same "random" images as its siblings. Reseeding every worker from
    # OS entropy on start-up keeps the samples distinct.
    with ProcessPoolExecutor(
        max_workers=max_workers, initializer=_reseed_worker
    ) as executor:
        yield from executor.map(
            create_sample_antialiased,
            range(num_samples),
            repeat(image_size),
            repeat(distance),
            repeat(radius),
            repeat(delta),
            chunksize=chunksize,
        )
def visualize_samples(dataset: np.ndarray, output_dir: str, grid_size: int = 5):
    """Save a grid preview of the first samples with detected centers marked.

    The first ``grid_size * grid_size`` images of ``dataset`` are shown in
    row-major order and every center found by ``detect_circle_centers`` is
    overlaid as a red dot. The figure is written to
    ``<output_dir>/sample_grid.png``; the directory must already exist.

    Args:
        dataset: Sequence of images, indexed along its first axis.
        output_dir: Directory that receives ``sample_grid.png``.
        grid_size: Number of rows and columns in the preview grid.
    """
    # squeeze=False keeps ``axes`` two-dimensional even for a 1x1 grid.
    fig, axes = plt.subplots(
        grid_size, grid_size, figsize=(10, 10), squeeze=False
    )
    for idx, ax in enumerate(axes.flat):
        if idx < len(dataset):
            img = dataset[idx]
            ax.imshow(img)
            # Centers come back as (row, col) while scatter expects (x, y),
            # hence the swapped order.
            for center in detect_circle_centers(img):
                ax.scatter(center[1], center[0], c="red", s=20)
        # Hide ticks and frame on every cell, including unused ones.
        ax.axis("off")
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "sample_grid.png"))
    plt.close(fig)
if __name__ == "__main__":
    # Generation parameters.
    image_size = 32
    radius = 3
    distance = 11
    delta = 4
    total_samples = 1_000_000

    # Create the output directory if it doesn't exist.
    output_dir = "data/circle_dataset"
    os.makedirs(output_dir, exist_ok=True)

    # Pre-allocate storage for every image (about 2.9 GiB of uint8 at these
    # settings); slots are filled in as the generator delivers samples.
    dataset = np.empty((total_samples, image_size, image_size, 3), dtype=np.uint8)
    samples = generate_circle_dataset(
        num_samples=total_samples,
        image_size=image_size,
        radius=radius,
        distance=distance,
        delta=delta,
    )
    for index, image in tqdm(samples, total=total_samples):
        dataset[index] = image

    # Write a preview grid, then the full dataset.
    visualize_samples(dataset, output_dir)
    np.save(os.path.join(output_dir, "data32.npy"), dataset)
    # Compressed alternative:
    # np.savez_compressed(os.path.join(output_dir, "data.npy.npz"), dataset)