You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
188 lines
6.1 KiB
188 lines
6.1 KiB
import numpy as np
|
|
from scipy.ndimage import label, center_of_mass
|
|
from PIL import Image, ImageDraw
|
|
|
|
from tqdm import tqdm
|
|
import os
|
|
|
|
from concurrent.futures import ProcessPoolExecutor
|
|
from itertools import repeat
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
# RGB colour constants (0-255 per channel); hex form shown alongside.
RED = (204, 36, 29)  # #CC241D
GREEN = (152, 151, 26)  # #98971A
BLUE = (69, 133, 136)  # #458588
BACKGROUND = (80, 73, 69)  # #504945
|
|
|
|
|
|
def create_sample_antialiased(
    id: int,
    image_size: int,
    distance: int,
    radius: int,
    delta: int,
    scale=4,
    rng=None,
):
    """Render two anti-aliased circles (GREEN and BLUE) on a BACKGROUND image.

    The circles are drawn at ``scale`` times the target resolution and the
    result is downsampled with a Lanczos filter, which smooths their edges.

    Args:
        id: Sample identifier, returned unchanged alongside the image.
        image_size: Side length, in pixels, of the returned square image.
        distance: Target distance between the two circle centers, in output
            pixels.
        radius: Circle radius in output pixels.
        delta: Allowed deviation from ``distance``. Center pairs are redrawn
            until their distance (measured at high resolution) lies in
            ``[(distance - delta) * scale, (distance + delta) * scale]``.
        scale: Supersampling factor used for anti-aliasing.
        rng: Optional ``numpy.random.Generator``. If omitted, a fresh
            generator seeded from OS entropy is created for this call, so
            samples stay independent even when produced in forked worker
            processes (which would otherwise share a copy of the global
            ``np.random`` state and emit duplicate samples).

    Returns:
        tuple: ``(id, image)`` where ``image`` is an RGB ``np.ndarray`` of
        shape ``(image_size, image_size, 3)``.

    Raises:
        ValueError: If the requested distance range can never be met inside
            the image (previously an infinite loop), or if the image is too
            small to hold a circle of ``radius``.
    """
    if rng is None:
        rng = np.random.default_rng()

    # Work at high resolution; downsampling later provides the anti-aliasing.
    high_res_size = image_size * scale
    high_res_radius = radius * scale
    min_dist = (distance - delta) * scale
    max_dist = (distance + delta) * scale

    # Centers are drawn from [low, high) on both axes so circles stay inside.
    low, high = high_res_radius, high_res_size - high_res_radius
    # The largest achievable center distance is the diagonal of that square;
    # this coarse check rejects ranges that would make the loop below spin
    # forever (it does not prove that a lattice pair exists for delta == 0).
    farthest = max(high - low - 1, 0) * np.sqrt(2)
    if min_dist > max_dist or max_dist < 0 or min_dist > farthest:
        raise ValueError(
            f"distance range [{distance - delta}, {distance + delta}] cannot be "
            f"satisfied in a {image_size}px image with radius {radius}"
        )

    im = Image.new("RGB", (high_res_size, high_res_size), BACKGROUND)
    draw = ImageDraw.Draw(im)

    # Rejection-sample center pairs until their distance is within range.
    while True:
        x0, y0 = rng.integers(low, high, size=2)
        x1, y1 = rng.integers(low, high, size=2)
        if min_dist <= np.hypot(x0 - x1, y0 - y1) <= max_dist:
            break

    # ImageDraw.circle requires Pillow >= 10.4. BLUE is drawn last, so it
    # covers GREEN wherever the circles overlap.
    draw.circle((x0, y0), high_res_radius, fill=GREEN)
    draw.circle((x1, y1), high_res_radius, fill=BLUE)

    # Lanczos downsampling turns the hard high-res edges into smooth ones.
    im = im.resize((image_size, image_size), resample=Image.LANCZOS)
    return id, np.array(im)
|
|
|
|
|
|
def create_sample(
    id: int, image_size: int, distance: int, radius: int, delta: int, rng=None
):
    """Render two hard-edged circles (GREEN and BLUE) on a BACKGROUND image.

    No supersampling is done: a pixel is coloured iff its integer
    coordinates lie within ``radius`` (inclusive) of a circle center.

    Args:
        id: Sample identifier, returned unchanged alongside the image.
        image_size: Side length, in pixels, of the returned square image.
        distance: Target distance between the two circle centers, in pixels.
        radius: Circle radius in pixels.
        delta: Allowed deviation from ``distance``; center pairs are redrawn
            until their distance lies in ``[distance - delta, distance + delta]``.
        rng: Optional ``numpy.random.Generator``. If omitted, a fresh
            generator seeded from OS entropy is created for this call, so
            samples stay independent even in forked worker processes.

    Returns:
        tuple: ``(id, image)`` where ``image`` is a ``uint8`` array of shape
        ``(image_size, image_size, 3)``.

    Raises:
        ValueError: If the requested distance range can never be met inside
            the image (previously an infinite loop), or if the image is too
            small to hold a circle of ``radius``.
    """
    if rng is None:
        rng = np.random.default_rng()

    min_dist = distance - delta
    max_dist = distance + delta

    # Centers are drawn from [low, high) on both axes so circles stay inside.
    low, high = radius, image_size - radius
    # Coarse feasibility check: the largest achievable distance is the
    # diagonal of the allowed center square.
    farthest = max(high - low - 1, 0) * np.sqrt(2)
    if min_dist > max_dist or max_dist < 0 or min_dist > farthest:
        raise ValueError(
            f"distance range [{min_dist}, {max_dist}] cannot be satisfied "
            f"in a {image_size}px image with radius {radius}"
        )

    img = np.full(
        shape=(image_size, image_size, 3), fill_value=BACKGROUND, dtype=np.uint8
    )

    # Rejection-sample center pairs until their distance is within range.
    while True:
        x0, y0 = rng.integers(low, high, size=2, dtype=np.int32)
        x1, y1 = rng.integers(low, high, size=2, dtype=np.int32)
        if min_dist <= np.hypot(x0 - x1, y0 - y1) <= max_dist:
            break

    # mgrid's first output varies along axis 0, so here x is the row index
    # and y is the column index.
    rows, cols = np.mgrid[:image_size, :image_size]
    mask0 = (rows - x0) ** 2 + (cols - y0) ** 2 <= radius**2
    mask1 = (rows - x1) ** 2 + (cols - y1) ** 2 <= radius**2

    img[mask0] = GREEN
    img[mask1] = BLUE  # painted last, so it wins where the circles overlap
    return id, img
|
|
|
|
|
|
def detect_circle_centers(image, background=BACKGROUND, threshold=30):
    """Locate circle centers as centroids of non-background connected regions.

    A pixel is foreground when at least one of its channels differs from
    ``background`` by strictly more than ``threshold``. Foreground pixels are
    grouped with ``scipy.ndimage.label`` using its default structuring
    element (4-connectivity in 2-D). The unweighted centroid of each group
    is returned.

    Args:
        image (array-like): RGB image with shape (H, W, 3).
        background (sequence of int): Background RGB colour to ignore.
        threshold (int): Per-channel difference that must be exceeded for a
            pixel to count as part of a circle.

    Returns:
        list of tuple: ``(row, col)`` centroid of each detected region, in
        label order. It is empty when no foreground pixel is found. Circles
        that touch merge into one region and therefore yield one center.
    """
    # int16 so the subtraction of uint8 values cannot wrap around.
    diff = np.abs(
        np.asarray(image, dtype=np.int16) - np.asarray(background, dtype=np.int16)
    )
    mask = np.any(diff > threshold, axis=-1)

    labeled, num_features = label(mask)
    if num_features == 0:
        return []

    # One vectorised call over all labels instead of one full-image pass per
    # label. With an index array, center_of_mass returns a list of tuples.
    return center_of_mass(mask, labeled, np.arange(1, num_features + 1))
|
|
|
|
|
|
def _reseed_worker():
    """Reseed the NumPy global RNG from OS entropy in a freshly started worker.

    Worker processes started with the ``fork`` method inherit a copy of the
    parent's ``np.random`` state. Without reseeding, every worker would draw
    the same "random" sequence and the dataset would contain duplicates.
    """
    np.random.seed()


def generate_circle_dataset(
    num_samples,
    image_size,
    radius,
    distance,
    delta,
    max_workers=32,
    chunksize=100,
):
    """Yield ``(index, image)`` pairs of two-circle images built in parallel.

    Each image is produced by ``create_sample_antialiased`` (a GREEN and a
    BLUE circle on BACKGROUND) in a process pool. Nothing is saved here; the
    caller decides what to do with the samples. The pool is started on the
    first ``next()`` and shut down when the generator finishes or is closed.

    Args:
        num_samples (int): Number of images to generate.
        image_size (int): Size of the square image (height and width).
        radius (int): Radius of the circles.
        distance (int): Base distance between the centers of the two circles.
        delta (int): Maximum variation in the distance between the circles.
        max_workers (int): Number of worker processes.
        chunksize (int): Number of samples sent to a worker per task.

    Yields:
        tuple: ``(index, image)`` for ``index`` in ``range(num_samples)``.
    """
    with ProcessPoolExecutor(
        max_workers=max_workers, initializer=_reseed_worker
    ) as executor:
        yield from executor.map(
            create_sample_antialiased,
            range(num_samples),
            repeat(image_size),
            repeat(distance),
            repeat(radius),
            repeat(delta),
            chunksize=chunksize,
        )
|
|
|
|
|
|
def visualize_samples(dataset: np.ndarray, output_dir: str, grid_size: int = 5):
    """Save a grid of the first samples with their detected circle centers.

    Draws the first ``grid_size * grid_size`` images of ``dataset``, overlays
    the centers found by ``detect_circle_centers`` as red dots, and writes
    the figure to ``<output_dir>/sample_grid.png``.

    Args:
        dataset: Integer-indexable collection of images, e.g. an array of
            shape (N, H, W, 3). Only the first ``grid_size ** 2`` are drawn.
        output_dir: Existing directory that receives ``sample_grid.png``.
        grid_size: Number of rows and columns in the grid.
    """
    # squeeze=False keeps `axes` 2-D even when grid_size == 1.
    fig, axes = plt.subplots(
        grid_size,
        grid_size,
        figsize=(2 * grid_size, 2 * grid_size),
        squeeze=False,
    )
    try:
        # axes.flat walks the grid row by row, matching idx = i * grid_size + j.
        for idx, ax in enumerate(axes.flat):
            if idx < len(dataset):
                img = dataset[idx]
                ax.imshow(img)
                # center_of_mass yields (row, col); scatter wants (x=col, y=row).
                for row, col in detect_circle_centers(img):
                    ax.scatter(col, row, c="red", s=20)
            # Hide ticks and frames, including on unused cells.
            ax.axis("off")

        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, "sample_grid.png"))
    finally:
        # Close this specific figure, even if detection or saving failed.
        plt.close(fig)
|
|
|
|
|
|
if __name__ == "__main__":
    # Dataset parameters (lengths in pixels).
    total_samples = 1_000_000
    image_size = 32
    distance = 11
    delta = 4
    radius = 3

    # Create output directory if it doesn't exist.
    output_dir = "data/circle_dataset"
    os.makedirs(output_dir, exist_ok=True)

    # Preallocate the (N, H, W, 3) uint8 tensor: about 3 GB at these settings.
    dataset = np.empty((total_samples, image_size, image_size, 3), dtype=np.uint8)
    iterator = generate_circle_dataset(
        image_size=image_size,
        num_samples=total_samples,
        distance=distance,
        delta=delta,
        radius=radius,
    )
    # Each sample is stored in the slot given by its index.
    for i, sample in tqdm(iterator, total=total_samples):
        dataset[i] = sample

    # Save first, so a failure in the optional plot cannot lose the data.
    np.save(os.path.join(output_dir, f"data{image_size}.npy"), dataset)
    # Compressed alternative:
    # np.savez_compressed(os.path.join(output_dir, "data.npy.npz"), dataset)

    visualize_samples(dataset, output_dir)
|
|
|