|
|
|
@ -1,3 +1,4 @@ |
|
|
|
from typing import Optional |
|
|
|
import tyro |
|
|
|
from functools import partial |
|
|
|
from dataclasses import dataclass |
|
|
|
@ -40,13 +41,19 @@ class Config: |
|
|
|
"""How many samples per mini-batch.""" |
|
|
|
|
|
|
|
r1: float = 0.3 |
|
|
|
"""Inner radius of the donut for p_data""" |
|
|
|
"""Inner radius of the donut for p_data.""" |
|
|
|
|
|
|
|
r2: float = 0.8 |
|
|
|
"""Outer radius of the donut for p_data""" |
|
|
|
"""Outer radius of the donut for p_data.""" |
|
|
|
|
|
|
|
d: float = 0.3 |
|
|
|
"""Size (radians) of the dent in the donut.""" |
|
|
|
|
|
|
|
sample_steps: int = 100 |
|
|
|
"""The number of steps taken during sampling""" |
|
|
|
"""The number of steps taken during sampling.""" |
|
|
|
|
|
|
|
evaluate_constraints_every: int = 0 |
|
|
|
"""Evaluate the constraints after given steps.""" |
|
|
|
|
|
|
|
seed: int = 42 |
|
|
|
"""The seed used for randomness.""" |
|
|
|
@ -55,16 +62,18 @@ class Config: |
|
|
|
# --- Data generation process ---------------------------- |
|
|
|
@partial(jax.jit, static_argnums=(1,)) |
|
|
|
def sample_p_data( |
|
|
|
key: jax.random.PRNGKey, num_samples: int, r1: float, r2: float |
|
|
|
key: jax.random.PRNGKey, num_samples: int, r1: float, r2: float, d: float = 0.0 |
|
|
|
) -> jax.Array: |
|
|
|
key_r, key_t = jax.random.split(key) |
|
|
|
|
|
|
|
u = jax.random.uniform(key_r, (num_samples,), minval=0.0, maxval=1.0) |
|
|
|
# radius distribution r => CDF(r) = (r^2 - r1^2)/(r^2 - r1^2) => invert: |
|
|
|
# radius distribution r => CDF(r) = (r^2 - r1^2)/(r2^2 - r1^2) => invert: |
|
|
|
r = jnp.sqrt(u * (r2**2 - r1**2) + r1**2) |
|
|
|
|
|
|
|
# Sample angle uniformly in [0, 2pi] |
|
|
|
theta = jax.random.uniform(key_t, (num_samples,), minval=0.0, maxval=2 * jnp.pi) |
|
|
|
# Sample angle uniformly in [d/2, 2pi-d/2] |
|
|
|
theta = jax.random.uniform( |
|
|
|
key_t, (num_samples,), minval=d / 2.0, maxval=2 * jnp.pi - d / 2.0 |
|
|
|
) |
|
|
|
|
|
|
|
# Convert to cartesian |
|
|
|
x = r * jnp.cos(theta) |
|
|
|
@ -73,6 +82,47 @@ def sample_p_data( |
|
|
|
return jnp.stack((x, y), axis=-1) |
|
|
|
|
|
|
|
|
|
|
|
def check_donut_constraint(x: jax.Array, r1: float, r2: float) -> jax.Array:
    """Flag which 2-D points lie strictly inside the donut (annulus).

    Args:
        x: Points whose last axis has size 2 and holds the ``(x, y)``
            coordinates, e.g. shape ``(num_samples, 2)``.
        r1: Inner radius of the donut.
        r2: Outer radius of the donut.

    Returns:
        Boolean array shaped like ``x`` without its last axis. An entry is
        True when ``r1**2 < x**2 + y**2 < r2**2`` (both bounds exclusive).

    Raises:
        ValueError: If the last axis of ``x`` does not have size 2.
    """
    x = jnp.asarray(x)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if x.ndim < 1 or x.shape[-1] != 2:
        raise ValueError(f"expected points of shape (..., 2), got {x.shape}")

    # Compare squared quantities so no square root is needed.
    r_sq = jnp.square(x[..., 0]) + jnp.square(x[..., 1])
    return jnp.logical_and(r_sq > r1**2, r_sq < r2**2)
|
|
|
|
|
|
|
|
|
|
|
def check_dent_constraint(x: jax.Array, d: float) -> jax.Array:
    """Flag which 2-D points lie outside the angular dent around angle 0.

    The dent is the open wedge of polar angles ``(-d / 2, d / 2)``; a point
    satisfies the constraint when its angle is NOT inside that wedge.

    Args:
        x: Points whose last axis has size 2 and holds the ``(x, y)``
            coordinates, e.g. shape ``(num_samples, 2)``.
        d: Angular size of the dent in radians.

    Returns:
        Boolean array shaped like ``x`` without its last axis. An entry is
        True when the point lies outside the dent.

    Raises:
        ValueError: If the last axis of ``x`` does not have size 2.
    """
    x = jnp.asarray(x)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if x.ndim < 1 or x.shape[-1] != 2:
        raise ValueError(f"expected points of shape (..., 2), got {x.shape}")

    # Polar angle in [-pi, pi]. `arctan2` is the long-standing name and is
    # used here instead of the `atan2` alias.
    theta = jnp.arctan2(x[..., 1], x[..., 0])

    # -d/2 < theta < d/2 is the same as |theta| < d/2. Negating the
    # "inside" test keeps NaN angles mapping to True as before.
    inside_dent = jnp.abs(theta) < d / 2
    return jnp.logical_not(inside_dent)
|
|
|
|
|
|
|
|
|
|
|
def array_bool_statistics(x: jax.Array) -> jax.Array:
    """Compute the fraction of True entries in a boolean array.

    The result is a fraction in ``[0, 1]``, not a percentage in ``[0, 100]``.

    Args:
        x: Boolean array of any shape.

    Returns:
        Scalar array equal to ``count_nonzero(x) / x.size``. Callers that
        need a Python float can call ``.item()`` on it.

    Raises:
        TypeError: If ``x`` does not have a boolean dtype.
    """
    x = jnp.asarray(x)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if x.dtype != jnp.bool_:
        raise TypeError(f"expected a boolean array, got dtype {x.dtype}")

    return jnp.count_nonzero(x) / x.size
|
|
|
|
|
|
|
|
|
|
|
def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int] = 0):
    """Sample from the model and measure how often the constraints hold.

    Draws 1024 samples via ``ode_trajectory`` and computes the fraction that
    satisfy the donut constraint and the fraction that satisfy the dent
    constraint. Optionally saves a scatter plot of the samples.

    NOTE(review): this function reads a module-level ``config`` (it is not a
    parameter); confirm that one is defined before this is called.

    Args:
        key: PRNG key; it is split and one half is used for sampling.
        model: Model forwarded to ``ode_trajectory``.
        save_plot: Suffix for the output file ``scatter_{save_plot}.png``.
            Pass ``None`` to skip plotting.

    Returns:
        Tuple ``(stats_donut, stats_dent)`` as returned by
        ``array_bool_statistics`` for each constraint.
    """
    # Keep the two-way split so the sampling key is the same as before; the
    # first half is currently unused.
    _, p_t_key = jax.random.split(key)

    samples = ode_trajectory(p_t_key, model=model, num_samples=1024, config=config)

    stats_donut = array_bool_statistics(
        check_donut_constraint(samples, r1=config.r1, r2=config.r2)
    )
    stats_dent = array_bool_statistics(check_dent_constraint(samples, d=config.d))

    if save_plot is not None:
        # Draw on a fresh figure and close it afterwards. Plotting on the
        # implicit current axes would overlay the points of every earlier
        # call and keep the figure alive.
        fig, ax = plt.subplots()
        sns.scatterplot(x=samples[:, 0], y=samples[:, 1], ax=ax)
        fig.savefig(f"scatter_{save_plot}.png", dpi=300, bbox_inches="tight")
        plt.close(fig)

    return stats_donut, stats_dent
|
|
|
|
|
|
|
|
|
|
|
# --- Model definition ----------------------------------- |
|
|
|
class MLP(nnx.Module): |
|
|
|
def __init__( |
|
|
|
@ -234,23 +284,53 @@ def main(config: Config): |
|
|
|
|
|
|
|
cached_train_step = nnx.cached_partial(train_step, model, optim) |
|
|
|
|
|
|
|
stats_donut = [] |
|
|
|
stats_dent = [] |
|
|
|
for i in tqdm(range(config.num_steps)): |
|
|
|
z = sample_p_data( |
|
|
|
rngs.params(), num_samples=config.batch_size, r1=config.r1, r2=config.r2 |
|
|
|
rngs.params(), |
|
|
|
num_samples=config.batch_size, |
|
|
|
r1=config.r1, |
|
|
|
r2=config.r2, |
|
|
|
d=config.d, |
|
|
|
) |
|
|
|
|
|
|
|
_ = cached_train_step(z, rngs.params()) |
|
|
|
|
|
|
|
# Generate samples |
|
|
|
|
|
|
|
print("sampling...", end="", flush=True) |
|
|
|
samples = ode_trajectory( |
|
|
|
key=rngs.params(), model=model, num_samples=1024, config=config |
|
|
|
) |
|
|
|
print(" done!") |
|
|
|
# samples = np.array(z) |
|
|
|
sns.scatterplot(x=samples[:, 0], y=samples[:, 1]) |
|
|
|
plt.savefig("scatter.png", dpi=300, bbox_inches="tight") |
|
|
|
if ( |
|
|
|
config.evaluate_constraints_every != 0 |
|
|
|
and i % config.evaluate_constraints_every == 0 |
|
|
|
): |
|
|
|
stat_donut, stat_dent = evaluate_constraints(rngs.params(), model) |
|
|
|
stats_donut.append(stat_donut.item()) |
|
|
|
stats_dent.append(stat_dent.item()) |
|
|
|
|
|
|
|
# Plot results |
|
|
|
# 1. Set a “pretty” style |
|
|
|
sns.set_theme(style="whitegrid", palette="pastel") |
|
|
|
|
|
|
|
# 2. Create a figure with two side-by-side axes |
|
|
|
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True) |
|
|
|
|
|
|
|
# 3. First subplot: stats_donut |
|
|
|
sns.lineplot(data=stats_donut, ax=axes[0]) |
|
|
|
axes[0].set_title("Donut Statistics", fontsize=14, fontweight="bold") |
|
|
|
axes[0].set_xlabel("Index") |
|
|
|
axes[0].set_ylabel("% Constraint satisfied") |
|
|
|
axes[0].grid(True, linestyle="--", alpha=0.7) |
|
|
|
|
|
|
|
# 4. Second subplot: stats_dent |
|
|
|
sns.lineplot(data=stats_dent, ax=axes[1]) |
|
|
|
axes[1].set_title("Dent Statistics", fontsize=14, fontweight="bold") |
|
|
|
axes[1].set_xlabel("Index") |
|
|
|
axes[1].set_ylabel("% Constraint satisfied") |
|
|
|
axes[1].grid(True, linestyle="--", alpha=0.7) |
|
|
|
|
|
|
|
# 5. Add an overall title and tighten layout |
|
|
|
fig.suptitle("Comparison of Donut vs. Dent Trends", fontsize=16, fontweight="bold") |
|
|
|
fig.tight_layout(rect=[0, 0.03, 1, 0.95]) |
|
|
|
|
|
|
|
plt.savefig("constraint_stats.png") |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|