diff --git a/train.py b/train.py index 14272bd..40e35ec 100644 --- a/train.py +++ b/train.py @@ -1,3 +1,4 @@ +from typing import Optional import tyro from functools import partial from dataclasses import dataclass @@ -40,13 +41,19 @@ class Config: """How many samples per mini-batch.""" r1: float = 0.3 - """Inner radius of the donut for p_data""" + """Inner radius of the donut for p_data.""" r2: float = 0.8 - """Outer radius of the donut for p_data""" + """Outer radius of the donut for p_data.""" + + d: float = 0.3 + """Size (radians) of the dent in the donut.""" sample_steps: int = 100 - """The number of steps taken during sampling""" + """The number of steps taken during sampling.""" + + evaluate_constraints_every: int = 0 + """Evaluate the constraints after given steps.""" seed: int = 42 """The seed used for randomness.""" @@ -55,16 +62,18 @@ class Config: # --- Data generation process ---------------------------- @partial(jax.jit, static_argnums=(1,)) def sample_p_data( - key: jax.random.PRNGKey, num_samples: int, r1: float, r2: float + key: jax.random.PRNGKey, num_samples: int, r1: float, r2: float, d: float = 0.0 ) -> jax.Array: key_r, key_t = jax.random.split(key) u = jax.random.uniform(key_r, (num_samples,), minval=0.0, maxval=1.0) - # radius distribution r => CDF(r) = (r^2 - r1^2)/(r^2 - r1^2) => invert: + # radius distribution r => CDF(r) = (r^2 - rj^2)/(r^2 - r1^2) => invert: r = jnp.sqrt(u * (r2**2 - r1**2) + r1**2) - # Sample angle uniformly in [0, 2pi] - theta = jax.random.uniform(key_t, (num_samples,), minval=0.0, maxval=2 * jnp.pi) + # Sample angle uniformly in [d/2, 2pi-d/2] + theta = jax.random.uniform( + key_t, (num_samples,), minval=d / 2.0, maxval=2 * jnp.pi - d / 2.0 + ) # Convert to cartesian x = r * jnp.cos(theta) @@ -73,6 +82,47 @@ def sample_p_data( return jnp.stack((x, y), axis=-1) +def check_donut_constraint(x: jax.Array, r1: float, r2: float) -> jax.Array: + assert len(x.shape) == 2 + assert x.shape[1] == 2 + + r_sq = jnp.square(x[:, 0]) + jnp.square(x[:, 1]) + return jnp.logical_and(r_sq > (r1**2), r_sq < (r2**2)) + + +def check_dent_constraint(x: jax.Array, d: float) -> jax.Array: + assert len(x.shape) == 2 + assert x.shape[1] == 2 + + theta = jnp.atan2(x[:, 1], x[:, 0]) + + return jnp.logical_not(jnp.logical_and(theta < d / 2, theta > -d / 2)) + + +def array_bool_statistics(x: jax.Array) -> float: + """Computes the % of True in a bool array.""" + assert x.dtype == jnp.bool_ + + return jnp.count_nonzero(x) / x.size + + +def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int] = 0): + p_data_key, p_t_key = jax.random.split(key) + + samples = ode_trajectory(p_t_key, model=model, num_samples=1024, config=config) + + stats_donut = array_bool_statistics( + check_donut_constraint(samples, r1=config.r1, r2=config.r2) + ) + stats_dent = array_bool_statistics(check_dent_constraint(samples, d=config.d)) + + if save_plot is not None: + sns.scatterplot(x=samples[:, 0], y=samples[:, 1]) + plt.savefig(f"scatter_{save_plot}.png", dpi=300, bbox_inches="tight") + + return stats_donut, stats_dent + + # --- Model definition ----------------------------------- class MLP(nnx.Module): def __init__( @@ -234,23 +284,53 @@ def main(config: Config): cached_train_step = nnx.cached_partial(train_step, model, optim) + stats_donut = [] + stats_dent = [] for i in tqdm(range(config.num_steps)): z = sample_p_data( - rngs.params(), num_samples=config.batch_size, r1=config.r1, r2=config.r2 + rngs.params(), + num_samples=config.batch_size, + r1=config.r1, + r2=config.r2, + d=config.d, ) _ = cached_train_step(z, rngs.params()) - # Generate samples - - print("sampling...", end="", flush=True) - samples = ode_trajectory( - key=rngs.params(), model=model, num_samples=1024, config=config - ) - print(" done!") - # samples = np.array(z) - sns.scatterplot(x=samples[:, 0], y=samples[:, 1]) - plt.savefig("scatter.png", dpi=300, bbox_inches="tight") + if ( + config.evaluate_constraints_every != 0 + and i % config.evaluate_constraints_every == 0 + ): + stat_donut, stat_dent = evaluate_constraints(rngs.params(), model) + stats_donut.append(stat_donut.item()) + stats_dent.append(stat_dent.item()) + + # Plot results + # 1. Set a “pretty” style + sns.set_theme(style="whitegrid", palette="pastel") + + # 2. Create a figure with two side-by-side axes + fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True) + + # 3. First subplot: stats_donut + sns.lineplot(data=stats_donut, ax=axes[0]) + axes[0].set_title("Donut Statistics", fontsize=14, fontweight="bold") + axes[0].set_xlabel("Index") + axes[0].set_ylabel("% Constraint satisfied") + axes[0].grid(True, linestyle="--", alpha=0.7) + + # 4. Second subplot: stats_dent + sns.lineplot(data=stats_dent, ax=axes[1]) + axes[1].set_title("Dent Statistics", fontsize=14, fontweight="bold") + axes[1].set_xlabel("Index") + axes[1].set_ylabel("% Constraint satisfied") + axes[1].grid(True, linestyle="--", alpha=0.7) + + # 5. Add an overall title and tighten layout + fig.suptitle("Comparison of Donut vs. Dent Trends", fontsize=16, fontweight="bold") + fig.tight_layout(rect=[0, 0.03, 1, 0.95]) + + plt.savefig("constraint_stats.png") if __name__ == "__main__":