You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 

338 lines
9.6 KiB

from typing import Optional
import tyro
from functools import partial
from dataclasses import dataclass
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax
@dataclass
class Config:
    """Flow/DDM training of a simple distribution."""

    space_dimensions: int = 2
    """The dimensionality of the distribution's space."""
    num_hidden_layers: int = 4
    """Number of hidden-to-hidden layers in the MLP."""
    hidden_size: int = 64
    """The size of the hidden layers of the MLP."""
    mlp_bias: bool = True
    """Enable the bias on every layer of the MLP."""
    fourier_dim: int = 6
    """Number of sinusoidal time features concatenated to the MLP input. Must be even."""
    fourier_max_period: float = 10_000.0
    """Sets the lowest time-feature frequency: frequencies are spaced geometrically from 1 down towards 1/fourier_max_period."""
    num_steps: int = 100_000
    """How many steps of gradient descent to perform."""
    batch_size: int = 512
    """How many samples per mini-batch."""
    r1: float = 0.3
    """Inner radius of the donut for p_data."""
    r2: float = 0.8
    """Outer radius of the donut for p_data."""
    d: float = 0.3
    """Angular width (radians) of the dent removed from the donut, centred on the positive x-axis."""
    sample_steps: int = 100
    """The number of Euler steps taken during sampling (must be >= 1)."""
    evaluate_constraints_every: int = 0
    """Evaluate the sample constraints every N training steps; 0 disables evaluation."""
    seed: int = 42
    """The seed used for randomness."""

    def __post_init__(self) -> None:
        """Validate settings up front, so bad CLI flags fail immediately.

        Raises:
            ValueError: if ``fourier_dim`` is negative or odd (the time embedding
                splits it into equal sin/cos halves), ``sample_steps`` < 1 (the
                sampler divides by it), ``evaluate_constraints_every`` < 0, or the
                radii do not satisfy ``0 <= r1 < r2``.
        """
        if self.fourier_dim < 0 or self.fourier_dim % 2 != 0:
            raise ValueError(
                f"fourier_dim must be a non-negative even integer, got {self.fourier_dim}"
            )
        if self.sample_steps < 1:
            raise ValueError(f"sample_steps must be >= 1, got {self.sample_steps}")
        if self.evaluate_constraints_every < 0:
            raise ValueError(
                "evaluate_constraints_every must be >= 0 (0 disables), "
                f"got {self.evaluate_constraints_every}"
            )
        if not 0.0 <= self.r1 < self.r2:
            raise ValueError(
                f"radii must satisfy 0 <= r1 < r2, got r1={self.r1}, r2={self.r2}"
            )
# --- Data generation process ----------------------------
@partial(jax.jit, static_argnames=("num_samples",))
def sample_p_data(
    key: jax.Array, num_samples: int, r1: float, r2: float, d: float = 0.0
) -> jax.Array:
    """Sample points uniformly (by area) from a 2-D annulus with a wedge removed.

    Args:
        key: JAX PRNG key.
        num_samples: number of points; static under ``jit`` because it fixes
            the output shape (may be passed positionally or by keyword).
        r1: inner radius of the annulus.
        r2: outer radius of the annulus.
        d: angular width (radians) of the removed wedge, centred on the
            positive x-axis; 0 keeps the full annulus.

    Returns:
        Array of shape ``(num_samples, 2)`` holding Cartesian ``(x, y)`` points.
    """
    key_r, key_t = jax.random.split(key)
    u = jax.random.uniform(key_r, (num_samples,), minval=0.0, maxval=1.0)
    # Area-uniform radius on [r1, r2] has CDF(r) = (r^2 - r1^2) / (r2^2 - r1^2);
    # inverting u = CDF(r) gives r = sqrt(u * (r2^2 - r1^2) + r1^2).
    r = jnp.sqrt(u * (r2**2 - r1**2) + r1**2)
    # Angle uniform on [d/2, 2*pi - d/2]: this leaves out the wedge (-d/2, d/2).
    theta = jax.random.uniform(
        key_t, (num_samples,), minval=d / 2.0, maxval=2 * jnp.pi - d / 2.0
    )
    # Convert polar -> Cartesian.
    x = r * jnp.cos(theta)
    y = r * jnp.sin(theta)
    return jnp.stack((x, y), axis=-1)
def check_donut_constraint(x: jax.Array, r1: float, r2: float) -> jax.Array:
    """Mark which 2-D points lie strictly inside the annulus ``r1 < |p| < r2``.

    Args:
        x: array of shape ``(N, 2)``.
        r1: inner radius (exclusive).
        r2: outer radius (exclusive).

    Returns:
        Boolean array of shape ``(N,)``.
    """
    assert x.ndim == 2
    assert x.shape[1] == 2
    squared_norm = jnp.sum(jnp.square(x), axis=1)
    return jnp.logical_and(squared_norm > r1**2, squared_norm < r2**2)
def check_dent_constraint(x: jax.Array, d: float) -> jax.Array:
    """Mark which 2-D points lie outside the dent.

    The dent is the open wedge of angular width ``d`` around the positive
    x-axis (polar angle strictly between ``-d/2`` and ``d/2``).

    Args:
        x: array of shape ``(N, 2)``.
        d: dent width in radians.

    Returns:
        Boolean array of shape ``(N,)``; ``False`` for points inside the dent.
    """
    assert x.ndim == 2
    assert x.shape[1] == 2
    angle = jnp.arctan2(x[:, 1], x[:, 0])
    half_width = d / 2
    inside_dent = jnp.logical_and(angle > -half_width, angle < half_width)
    return jnp.logical_not(inside_dent)
def array_bool_statistics(x: jax.Array) -> jax.Array:
    """Return the fraction of ``True`` entries in a boolean array.

    The result is a fraction in ``[0, 1]``, not a percentage. It is returned as
    a 0-d JAX array, so callers can use ``.item()`` to get a Python float.
    An empty ``x`` gives 0 / 0, i.e. nan.
    """
    assert x.dtype == jnp.bool_
    return jnp.count_nonzero(x) / x.size
def evaluate_constraints(
    key: jax.Array,
    model,
    save_plot: Optional[int] = 0,
    config: Optional[Config] = None,
):
    """Sample from ``model`` and measure how often the samples satisfy the data constraints.

    Args:
        key: JAX PRNG key; it is split, and the second half seeds the ODE sampler.
        model: velocity-field model passed to ``ode_trajectory``.
        save_plot: if not None, a scatter plot of the samples is written to
            ``scatter_{save_plot}.png`` (the default 0 writes ``scatter_0.png``).
        config: experiment configuration. When omitted, the module-level
            ``config`` (bound in the ``__main__`` guard) is used.

    Returns:
        ``(stats_donut, stats_dent)``: fractions of the 1024 samples satisfying
        the donut and dent constraints, each a 0-d JAX array.

    Raises:
        ValueError: if ``config`` is omitted and no module-level ``config`` exists
            (e.g. when ``main`` is called from another module).
    """
    if config is None:
        config = globals().get("config")
        if config is None:
            raise ValueError(
                "evaluate_constraints needs a Config: pass config=... explicitly"
            )
    # Keep the split so the sampler key matches the second half of the split.
    _, sample_key = jax.random.split(key)
    samples = ode_trajectory(sample_key, model=model, num_samples=1024, config=config)
    stats_donut = array_bool_statistics(
        check_donut_constraint(samples, r1=config.r1, r2=config.r2)
    )
    stats_dent = array_bool_statistics(check_dent_constraint(samples, d=config.d))
    if save_plot is not None:
        # Use a fresh figure and close it afterwards, so repeated evaluations
        # do not pile their points onto one shared axes or leak open figures.
        fig, ax = plt.subplots()
        sns.scatterplot(x=samples[:, 0], y=samples[:, 1], ax=ax)
        fig.savefig(f"scatter_{save_plot}.png", dpi=300, bbox_inches="tight")
        plt.close(fig)
    return stats_donut, stats_dent
# --- Model definition -----------------------------------
class MLP(nnx.Module):
    """Time-conditioned MLP ``f(x, t)``.

    Sinusoidal features of ``t`` are concatenated to ``x``, and the result is fed
    through: Linear(in_features + fourier_features -> hidden_size), SiLU, then
    ``num_hidden_layers`` x [Linear(hidden_size -> hidden_size), SiLU], then
    Linear(hidden_size -> out_features).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        fourier_features: int,
        num_hidden_layers: int,
        hidden_size: int,
        use_bias: bool,
        fourier_max_period: float,
        rngs: nnx.Rngs,
    ) -> None:
        """Build the layer stack.

        Args:
            in_features: size of the last axis of ``x``.
            out_features: size of the network output.
            fourier_features: number of time features; must be even.
            num_hidden_layers: number of hidden -> hidden Linear + SiLU blocks.
            hidden_size: width of every hidden layer.
            use_bias: whether every Linear layer has a bias.
            fourier_max_period: ``max_period`` passed to ``time_embed``.
            rngs: NNX RNG streams used to initialise the Linear layers.
        """
        self.fourier_dim = fourier_features
        self.fourier_max_period = fourier_max_period
        network = [
            nnx.Linear(
                in_features=in_features + fourier_features,
                out_features=hidden_size,
                use_bias=use_bias,
                rngs=rngs,
            ),
            nnx.silu,
        ]
        for _ in range(num_hidden_layers):
            network.append(
                nnx.Linear(
                    in_features=hidden_size,
                    out_features=hidden_size,
                    use_bias=use_bias,
                    rngs=rngs,
                )
            )
            network.append(nnx.silu)
        network.append(
            nnx.Linear(
                in_features=hidden_size,
                out_features=out_features,
                use_bias=use_bias,
                rngs=rngs,
            )
        )
        self.network = nnx.Sequential(*network)

    def time_embed(
        self, t: jax.Array, embed_dim: int, max_period: float = 10_000.0
    ) -> jax.Array:
        """Sinusoidal embedding of ``t``.

        Frequencies are ``exp(-log(max_period) * k / half)`` for
        ``k in range(half)``, where ``half = embed_dim // 2``. The output is
        ``[sin(t * f), cos(t * f)]`` along the last axis.

        Args:
            t: times, either with a trailing unit axis ``(..., 1)`` or without
                one; a scalar or an array whose last dim is not 1 gets a
                trailing axis added.
            embed_dim: number of output features; must be even.
            max_period: controls the lowest frequency.

        Returns:
            Array of shape ``(..., embed_dim)``.

        Raises:
            ValueError: if ``embed_dim`` is odd.
        """
        if embed_dim % 2 != 0:
            raise ValueError("embed_dim must be even")
        t = jnp.asarray(t)
        if t.ndim == 0 or t.shape[-1] != 1:
            t = t[..., None]
        half_dim = embed_dim // 2
        freqs = jnp.exp(
            -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) / half_dim
        )
        args = t * freqs
        return jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)

    def __call__(self, x: jax.Array, t: jax.Array) -> jax.Array:
        """Evaluate the network.

        Args:
            x: inputs of shape ``(..., in_features)``.
            t: times as a scalar, or with shape ``x.shape[:-1]`` or
                ``x.shape[:-1] + (1,)``. It is broadcast to one time per input row.

        Returns:
            Array of shape ``(..., out_features)``.
        """
        t = jnp.asarray(t)
        # Normalise t to shape (..., 1) here, where x's rank is known:
        # time_embed alone cannot tell a batch of one time (shape (1,)) from an
        # already-expanded time, which broke batch size 1.
        if t.ndim == x.ndim - 1:
            t = t[..., None]
        t = jnp.broadcast_to(t, (*x.shape[:-1], 1))
        t_encoded = self.time_embed(t, self.fourier_dim, self.fourier_max_period)
        return self.network(jnp.concatenate((x, t_encoded), axis=-1))
# --- Diffusion functions -----------------------------
def alpha(t: jax.Array) -> jax.Array:
    """Data weight of the path x_t = alpha(t) * z + beta(t) * eps, i.e. t clipped to [0, 1]."""
    return jnp.clip(t, 0.0, 1.0)


def beta(t: jax.Array) -> jax.Array:
    """Noise weight of the same path: 1 - alpha(t)."""
    return 1.0 - alpha(t)


# Elementwise time derivatives of the schedules: grad of the scalar function,
# vectorised over a 1-D batch of times and jit-compiled.
alpha_scalar_grad = jax.grad(alpha)
beta_scalar_grad = jax.grad(beta)
alpha_grad = jax.jit(jax.vmap(alpha_scalar_grad))
beta_grad = jax.jit(jax.vmap(beta_scalar_grad))
def ode_step(model: MLP, x_t: jax.Array, t: jax.Array, h: float) -> jax.Array:
    """Take one explicit Euler step of size ``h`` along the velocity ``model(x_t, t)``."""
    velocity = model(x_t, t)
    return x_t + h * velocity
def ode_trajectory(
    key: jax.random.PRNGKey, model: MLP, num_samples: int, config: Config
) -> jax.Array:
    """Generate samples by Euler-integrating the learned ODE from t=0 to t=1.

    Starts from standard-normal noise of shape
    ``(num_samples, config.space_dimensions)`` and applies
    ``config.sample_steps`` steps of size ``1 / config.sample_steps``.
    """
    step_size = 1.0 / config.sample_steps
    times = jnp.zeros((num_samples,))
    samples = jax.random.normal(key=key, shape=(num_samples, config.space_dimensions))
    for _ in range(config.sample_steps):
        samples = ode_step(model, samples, times, step_size)
        times = times + step_size
    return samples
def sde_step(model: MLP, x_t: jax.Array, t: jax.Array, sigma_t: jax.Array) -> jax.Array:
    """Placeholder for one stochastic (SDE) sampling step; not implemented.

    Raises:
        NotImplementedError: always, so accidental use fails loudly instead of
            returning ``None`` where a ``jax.Array`` is expected.
    """
    # TODO(review): an Euler-Maruyama step would presumably also need a step
    # size and a PRNG key, which this signature does not take yet.
    raise NotImplementedError("sde_step is not implemented yet")
def sde_trajectory(model: MLP) -> jax.Array:
    """Placeholder for SDE-based sampling; not implemented.

    Raises:
        NotImplementedError: always, so accidental use fails loudly instead of
            returning ``None`` where a ``jax.Array`` is expected.
    """
    raise NotImplementedError("sde_trajectory is not implemented yet")
# --- Training ----------------------------------------
def _plot_constraint_stats(
    steps: list,
    stats_donut: list,
    stats_dent: list,
    path: str = "constraint_stats.png",
) -> None:
    """Save side-by-side line plots of the donut/dent constraint histories to ``path``."""
    sns.set_theme(style="whitegrid", palette="pastel")
    fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
    panels = (
        (axes[0], stats_donut, "Donut Statistics"),
        (axes[1], stats_dent, "Dent Statistics"),
    )
    for ax, stats, title in panels:
        sns.lineplot(x=steps, y=stats, ax=ax)
        ax.set_title(title, fontsize=14, fontweight="bold")
        ax.set_xlabel("Training step")
        # array_bool_statistics yields a fraction in [0, 1], not a percentage.
        ax.set_ylabel("Fraction of samples satisfying constraint")
        ax.grid(True, linestyle="--", alpha=0.7)
    fig.suptitle("Comparison of Donut vs. Dent Trends", fontsize=16, fontweight="bold")
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig.savefig(path)
    plt.close(fig)


def main(config: Config):
    """Train the time-conditioned MLP with a flow-matching loss on ``sample_p_data``.

    Each step draws a data batch ``z``, noise ``eps`` and times ``t``, builds
    ``x = alpha(t) z + beta(t) eps`` and regresses the model output onto
    ``alpha'(t) z + beta'(t) eps``. If ``config.evaluate_constraints_every`` is
    non-zero, the model is sampled every that many steps via
    ``evaluate_constraints``. The resulting history is plotted to
    ``constraint_stats.png`` at the end.
    """
    rngs = nnx.Rngs(config.seed)
    model = MLP(
        in_features=config.space_dimensions,
        out_features=config.space_dimensions,
        num_hidden_layers=config.num_hidden_layers,
        hidden_size=config.hidden_size,
        use_bias=config.mlp_bias,
        fourier_features=config.fourier_dim,
        fourier_max_period=config.fourier_max_period,
        rngs=rngs,
    )
    optim = nnx.Optimizer(
        model,
        tx=optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=3e-4)),
    )

    @nnx.jit
    def train_step(model: MLP, optim: nnx.Optimizer, z: jax.Array, key: jax.Array):
        """Apply one optimizer update on data batch ``z`` and return the batch loss."""
        key_e, key_t = jax.random.split(key)
        eps = jax.random.normal(key=key_e, shape=z.shape)
        t = jax.random.uniform(key=key_t, shape=[z.shape[0]])
        x = alpha(t)[:, None] * z + beta(t)[:, None] * eps
        # Regression target: time derivative of the path, alpha'(t) z + beta'(t) eps.
        # It does not depend on the model, so it is computed outside loss_fn.
        target = alpha_grad(t)[:, None] * z + beta_grad(t)[:, None] * eps

        def loss_fn(model):
            squared_error = jnp.sum((model(x, t) - target) ** 2, axis=-1)
            return jnp.mean(squared_error)

        loss, grads = nnx.value_and_grad(loss_fn)(model)
        optim.update(grads)
        return loss

    cached_train_step = nnx.cached_partial(train_step, model, optim)
    eval_steps = []
    stats_donut = []
    stats_dent = []
    progress = tqdm(range(config.num_steps))
    for step in progress:
        # The data key is drawn before the train-step key; this order fixes the
        # random stream for a given seed.
        z = sample_p_data(
            rngs.params(),
            num_samples=config.batch_size,
            r1=config.r1,
            r2=config.r2,
            d=config.d,
        )
        loss = cached_train_step(z, rngs.params())
        if (
            config.evaluate_constraints_every != 0
            and step % config.evaluate_constraints_every == 0
        ):
            stat_donut, stat_dent = evaluate_constraints(rngs.params(), model)
            eval_steps.append(step)
            stats_donut.append(stat_donut.item())
            stats_dent.append(stat_dent.item())
            # float(loss) waits for the device, so only read it at evaluation steps.
            progress.set_postfix(
                loss=float(loss), donut=stats_donut[-1], dent=stats_dent[-1]
            )

    _plot_constraint_stats(eval_steps, stats_donut, stats_dent)
if __name__ == "__main__":
    # Parse CLI flags into a Config. The result is bound to the module-level
    # name ``config`` on purpose: evaluate_constraints() reads that global.
    config = tyro.cli(Config)
    main(config=config)