Browse Source

feat: check constraints and plots

master
CALVO GONZALEZ Ramon 7 months ago
parent
commit
d185a4d1fe
  1. 116
      train.py

116
train.py

@ -1,3 +1,4 @@
from typing import Optional
import tyro
from functools import partial
from dataclasses import dataclass
@ -40,13 +41,19 @@ class Config:
"""How many samples per mini-batch."""
r1: float = 0.3
"""Inner radius of the donut for p_data"""
"""Inner radius of the donut for p_data."""
r2: float = 0.8
"""Outer radius of the donut for p_data"""
"""Outer radius of the donut for p_data."""
d: float = 0.3
"""Size (radians) of the dent in the donut."""
sample_steps: int = 100
"""The number of steps taken during sampling"""
"""The number of steps taken during sampling."""
evaluate_constraints_every: int = 0
"""Evaluate the constraints after given steps."""
seed: int = 42
"""The seed used for randomness."""
@ -55,16 +62,18 @@ class Config:
# --- Data generation process ----------------------------
@partial(jax.jit, static_argnums=(1,))
def sample_p_data(
key: jax.random.PRNGKey, num_samples: int, r1: float, r2: float
key: jax.random.PRNGKey, num_samples: int, r1: float, r2: float, d: float = 0.0
) -> jax.Array:
key_r, key_t = jax.random.split(key)
u = jax.random.uniform(key_r, (num_samples,), minval=0.0, maxval=1.0)
# radius distribution r => CDF(r) = (r^2 - r1^2)/(r^2 - r1^2) => invert:
# radius distribution r => CDF(r) = (r^2 - r1^2)/(r2^2 - r1^2) => invert:
r = jnp.sqrt(u * (r2**2 - r1**2) + r1**2)
# Sample angle uniformly in [0, 2pi]
theta = jax.random.uniform(key_t, (num_samples,), minval=0.0, maxval=2 * jnp.pi)
# Sample angle uniformly in [d/2, 2pi-d/2]
theta = jax.random.uniform(
key_t, (num_samples,), minval=d / 2.0, maxval=2 * jnp.pi - d / 2.0
)
# Convert to cartesian
x = r * jnp.cos(theta)
@ -73,6 +82,47 @@ def sample_p_data(
return jnp.stack((x, y), axis=-1)
def check_donut_constraint(x: jax.Array, r1: float, r2: float) -> jax.Array:
    """Flag which 2-D points lie strictly inside the annulus r1 < |p| < r2.

    Args:
        x: Array of shape ``(n, 2)`` holding one point per row.
        r1: Inner radius (exclusive).
        r2: Outer radius (exclusive).

    Returns:
        Boolean array of shape ``(n,)``; ``True`` where the point satisfies
        the constraint.
    """
    assert x.ndim == 2
    assert x.shape[1] == 2
    xs, ys = x[:, 0], x[:, 1]
    # Compare squared radii so no square root is needed.
    squared_radius = jnp.square(xs) + jnp.square(ys)
    outside_hole = squared_radius > (r1**2)
    inside_rim = squared_radius < (r2**2)
    return jnp.logical_and(outside_hole, inside_rim)
def check_dent_constraint(x: jax.Array, d: float) -> jax.Array:
    """Flag which 2-D points lie outside the angular dent of width ``d``.

    The dent is the open wedge of polar angles ``(-d/2, d/2)`` around the
    positive x-axis.

    Args:
        x: Array of shape ``(n, 2)`` holding one point per row.
        d: Angular width of the dent in radians.

    Returns:
        Boolean array of shape ``(n,)``; ``True`` where the point is NOT in
        the dent.
    """
    assert x.ndim == 2
    assert x.shape[1] == 2
    half_width = d / 2
    # atan2 yields the polar angle, so the dent is centred on angle 0.
    angle = jnp.atan2(x[:, 1], x[:, 0])
    below_upper_edge = angle < half_width
    above_lower_edge = angle > -half_width
    in_dent = jnp.logical_and(below_upper_edge, above_lower_edge)
    return jnp.logical_not(in_dent)
def array_bool_statistics(x: jax.Array) -> jax.Array:
    """Return the fraction of ``True`` entries in a boolean array.

    The value is a ratio in ``[0, 1]``; it is not scaled to a percentage.

    Args:
        x: Array with boolean dtype.

    Returns:
        A scalar JAX array (not a Python float); call ``.item()`` on it to
        obtain a plain Python number.
    """
    assert x.dtype == jnp.bool_
    return jnp.count_nonzero(x) / x.size
def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int] = 0):
    """Sample from the model and measure how often the constraints hold.

    Draws 1024 samples with ``ode_trajectory`` and evaluates the donut
    (radius) and dent (angle) constraints on them.

    NOTE(review): ``config`` is read from the enclosing module scope rather
    than passed in -- confirm it is defined at module level.

    Args:
        key: PRNG key. It is split in two; the second half drives sampling.
        model: Model forwarded to ``ode_trajectory``.
        save_plot: Tag inserted into the output file name
            ``scatter_{save_plot}.png``. Pass ``None`` to skip plotting.
            With the default ``0`` every call writes ``scatter_0.png``.

    Returns:
        Tuple ``(stats_donut, stats_dent)`` as produced by
        ``array_bool_statistics`` for each constraint.
    """
    # Keep the two-way split so the sampling key, and therefore the generated
    # samples, match the previous behaviour; the first half is unused.
    _, p_t_key = jax.random.split(key)
    samples = ode_trajectory(p_t_key, model=model, num_samples=1024, config=config)

    stats_donut = array_bool_statistics(
        check_donut_constraint(samples, r1=config.r1, r2=config.r2)
    )
    stats_dent = array_bool_statistics(check_dent_constraint(samples, d=config.d))

    if save_plot is not None:
        # Use a dedicated figure and close it afterwards. Drawing on the
        # implicit current axes would overlay the points of every earlier
        # call in the saved image and keep the figures alive.
        fig, ax = plt.subplots()
        sns.scatterplot(x=samples[:, 0], y=samples[:, 1], ax=ax)
        fig.savefig(f"scatter_{save_plot}.png", dpi=300, bbox_inches="tight")
        plt.close(fig)

    return stats_donut, stats_dent
# --- Model definition -----------------------------------
class MLP(nnx.Module):
def __init__(
@ -234,23 +284,53 @@ def main(config: Config):
cached_train_step = nnx.cached_partial(train_step, model, optim)
stats_donut = []
stats_dent = []
for i in tqdm(range(config.num_steps)):
z = sample_p_data(
rngs.params(), num_samples=config.batch_size, r1=config.r1, r2=config.r2
rngs.params(),
num_samples=config.batch_size,
r1=config.r1,
r2=config.r2,
d=config.d,
)
_ = cached_train_step(z, rngs.params())
# Generate samples
print("sampling...", end="", flush=True)
samples = ode_trajectory(
key=rngs.params(), model=model, num_samples=1024, config=config
)
print(" done!")
# samples = np.array(z)
sns.scatterplot(x=samples[:, 0], y=samples[:, 1])
plt.savefig("scatter.png", dpi=300, bbox_inches="tight")
if (
config.evaluate_constraints_every != 0
and i % config.evaluate_constraints_every == 0
):
stat_donut, stat_dent = evaluate_constraints(rngs.params(), model)
stats_donut.append(stat_donut.item())
stats_dent.append(stat_dent.item())
# Plot results
# 1. Set a “pretty” style
sns.set_theme(style="whitegrid", palette="pastel")
# 2. Create a figure with two side-by-side axes
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
# 3. First subplot: stats_donut
sns.lineplot(data=stats_donut, ax=axes[0])
axes[0].set_title("Donut Statistics", fontsize=14, fontweight="bold")
axes[0].set_xlabel("Index")
axes[0].set_ylabel("% Constraint satisfied")
axes[0].grid(True, linestyle="--", alpha=0.7)
# 4. Second subplot: stats_dent
sns.lineplot(data=stats_dent, ax=axes[1])
axes[1].set_title("Dent Statistics", fontsize=14, fontweight="bold")
axes[1].set_xlabel("Index")
axes[1].set_ylabel("% Constraint satisfied")
axes[1].grid(True, linestyle="--", alpha=0.7)
# 5. Add an overall title and tighten layout
fig.suptitle("Comparison of Donut vs. Dent Trends", fontsize=16, fontweight="bold")
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.savefig("constraint_stats.png")
if __name__ == "__main__":

Loading…
Cancel
Save