You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

595 lines
18 KiB

from typing import Optional
import os
from pathlib import Path
import tyro
from functools import partial
from dataclasses import dataclass
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from filelock import FileLock, Timeout
import numpy as np
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax
# Limit the share of accelerator memory XLA preallocates for this process to 10%.
# NOTE(review): JAX presumably reads this when its backend is first initialised
# (first computation), not at `import jax`, so setting it after the imports still
# takes effect — confirm against the JAX GPU memory-allocation docs.
os.environ.update({"XLA_PYTHON_CLIENT_MEM_FRACTION": "0.1"})
# NOTE(review): this class is parsed by `tyro.cli(Config)` in the __main__ guard;
# the string placed under each field presumably becomes that flag's --help text,
# so those strings are left untouched and extra notes live in `#` comments.
@dataclass(frozen=True)
class Config:
    """Flow/DDM training of a simple distribution."""

    # NOTE(review): sample_p_data and the constraint checks only produce/accept
    # 2-D points, so values other than 2 look unsupported — confirm.
    space_dimensions: int = 2
    """The dimensionality of the distribution's space."""
    # When set, training uses (x0, x1) pairs from brownian_motion_step and the
    # model additionally receives x0 as input (doubling its input width).
    brownian_motion: bool = False
    """Enables brownian motion."""
    # Standard deviation of the Gaussian perturbation applied in brownian_motion.
    brownian_sigma: float = 0.1
    """The size of the brownian motion."""
    num_hidden_layers: int = 4
    """Number of hidden layers in the MLP."""
    hidden_size: int = 64
    """The size of the hidden layers of the MLP."""
    mlp_bias: bool = True
    """Enable the bias on every layer of the MLP."""
    # Must be even: MLP.time_embed splits it into equal sin and cos halves.
    fourier_dim: int = 6
    """Fourier dimensions. Will be concatenated to the input of the MLP."""
    # Sets the lowest frequency of the sinusoidal time embedding.
    fourier_max_period: float = 10_000.0
    """Range of features of the Fourier features."""
    num_steps: int = 100_000
    """How many steps of gradient descent to perform."""
    batch_size: int = 256
    """How many samples per mini-batch."""
    r1: float = 0.3
    """Inner radius of the donut for p_data."""
    r2: float = 0.8
    """Outer radius of the donut for p_data."""
    # The dent is the wedge |angle| < d / 2 around the positive x-axis.
    d: float = 0.3
    """Size (radians) of the dent in the donut."""
    # Number of Euler steps used when sampling the learned ODE.
    sample_steps: int = 100
    """The number of steps taken during sampling."""
    # 0 disables evaluation; otherwise main evaluates at steps i > 0 with
    # i % evaluate_constraints_every == 0.
    evaluate_constraints_every: int = 0
    """Evaluate the constraints after given steps."""
    # Plots are written under results/mlp[_b]_l<layers>_h<hidden>/.
    save_scatterplot: bool = False
    """For every step the constraints are evaluated, save a scatterplot."""
    # When set, main only writes the p_data plot and returns without training.
    show_p_data: bool = False
    """If set, the script will generate a scatterplot sampled from the p_data process."""
    seed: int = 42
    """The seed used for randomness."""
# --- Data generation process ----------------------------
@partial(jax.jit, static_argnums=(1,))
def sample_p_data(
    key: jax.random.PRNGKey, num_samples: int, r1: float, r2: float, d: float = 0.0
) -> jax.Array:
    """Draw points uniformly (by area) from an annulus with a wedge cut out.

    The support is ``r1 <= |x| < r2`` minus the wedge ``|angle| < d / 2`` centred
    on the positive x-axis.

    Args:
        key: PRNG key.
        num_samples: Number of points. Static under ``jax.jit`` because it fixes
            the output shape.
        r1: Inner radius.
        r2: Outer radius.
        d: Angular width (radians) of the removed wedge.

    Returns:
        Array of shape ``(num_samples, 2)`` holding Cartesian coordinates.
    """
    key_r, key_t = jax.random.split(key)
    u = jax.random.uniform(key_r, (num_samples,), minval=0.0, maxval=1.0)
    # Uniform-by-area radius: CDF(r) = (r^2 - r1^2) / (r2^2 - r1^2) on [r1, r2].
    # Solving u = CDF(r) for r gives r = sqrt(u * (r2^2 - r1^2) + r1^2).
    r = jnp.sqrt(u * (r2**2 - r1**2) + r1**2)
    # Angle uniform on [d/2, 2*pi - d/2], i.e. skipping the wedge around angle 0.
    theta = jax.random.uniform(
        key_t, (num_samples,), minval=d / 2.0, maxval=2 * jnp.pi - d / 2.0
    )
    # Convert polar -> Cartesian.
    x = r * jnp.cos(theta)
    y = r * jnp.sin(theta)
    return jnp.stack((x, y), axis=-1)
def check_donut_constraint(x: jax.Array, r1: float, r2: float) -> jax.Array:
    """Return a boolean mask of the points strictly inside the annulus ``r1 < |x| < r2``.

    ``x`` must be a rank-2 array whose second axis has size 2 (one row per point);
    anything else trips an assertion.
    """
    assert x.ndim == 2
    assert x.shape[-1] == 2
    # Compare squared norms against squared radii to avoid a square root.
    sq_norm = jnp.square(x[:, 0]) + jnp.square(x[:, 1])
    beyond_inner = sq_norm > r1**2
    within_outer = sq_norm < r2**2
    return jnp.logical_and(beyond_inner, within_outer)
def check_dent_constraint(x: jax.Array, d: float) -> jax.Array:
    """Return a boolean mask of the points lying outside the dent wedge.

    The wedge is ``-d/2 < angle < d/2``, where ``angle = atan2(y, x)``, i.e. it is
    centred on the positive x-axis. ``x`` must be a rank-2 array with two columns.
    """
    assert x.ndim == 2
    assert x.shape[-1] == 2
    angle = jnp.arctan2(x[:, 1], x[:, 0])
    inside_wedge = jnp.abs(angle) < d / 2
    return jnp.logical_not(inside_wedge)
def brownian_motion(key: jax.random.PRNGKey, x0: jax.Array, sigma: float) -> jax.Array:
    """Perturb ``x0`` with isotropic Gaussian noise of standard deviation ``sigma``.

    One standard-normal draw (from ``key``) is used per element of ``x0``.
    """
    unit_noise = jax.random.normal(key=key, shape=x0.shape)
    return x0 + sigma * unit_noise
@partial(jax.jit, static_argnames=("num_samples", "num_check_pts"))
def brownian_motion_step(
    key: jax.random.PRNGKey,
    num_samples: int,
    r1: float,
    r2: float,
    d: float = 0.0,
    sigma: float = 0.1,
    num_check_pts: int = 8,
) -> tuple[jax.Array, jax.Array]:
    """Pair each p_data sample with a constrained Brownian perturbation of it.

    Draws ``x0`` from ``sample_p_data`` and then, by rejection sampling, a point
    ``x1 = x0 + sigma * N(0, I)``. A proposal is accepted only if ``x1`` and
    ``num_check_pts`` evenly spaced interior points of the segment ``x0 -> x1``
    all lie inside the annulus and outside the dent. Each sample keeps the first
    proposal it accepts.

    Args:
        key: PRNG key.
        num_samples: Number of pairs. Static: it fixes the array shapes.
        r1: Inner radius of the annulus.
        r2: Outer radius of the annulus.
        d: Angular width (radians) of the dent.
        sigma: Standard deviation of the Gaussian perturbation.
        num_check_pts: Number of interior points checked along each segment.
            Static: it sets the length of the ``linspace`` grid, so it cannot be
            a traced value.

    Returns:
        ``(x0, x1)``, each of shape ``(num_samples, 2)``.

    Note:
        The loop runs until every sample has been accepted; there is no
        iteration cap.
    """
    key, key_init = jax.random.split(key)
    x0 = sample_p_data(key_init, num_samples, r1, r2, d)

    def points_satisfy(x: jax.Array) -> jax.Array:
        # Inside the annulus AND outside the dent wedge.
        return jnp.logical_and(
            check_donut_constraint(x, r1, r2), check_dent_constraint(x, d)
        )

    # Interior fractions of the segment to test (both endpoints excluded).
    ts = jnp.linspace(0.0, 1.0, num_check_pts + 2)[1:-1]

    def cond_fn(state):
        _, _, accepted = state
        return jnp.any(~accepted)

    def body_fn(state):
        key, x, accepted = state
        key, subkey = jax.random.split(key)
        # A fresh proposal from x0 for every row; only not-yet-accepted rows use it.
        x_prop = brownian_motion(subkey, x0, sigma)
        ok_end = points_satisfy(x_prop)
        # For rows not yet accepted x == x0, so this walks the segment x0 -> x_prop.
        diff = x_prop - x
        pts = x[:, None, :] + diff[:, None, :] * ts[None, :, None]
        ok_interior = jnp.all(
            points_satisfy(pts.reshape(-1, pts.shape[-1])).reshape(num_samples, -1),
            axis=1,
        )
        satisfied = ok_end & ok_interior
        x = jnp.where(jnp.logical_and(~accepted, satisfied)[:, None], x_prop, x)
        accepted = jnp.logical_or(accepted, satisfied)
        return key, x, accepted

    init_accepted = jnp.zeros((num_samples,), dtype=bool)
    _, x1, _ = jax.lax.while_loop(cond_fn, body_fn, (key, x0, init_accepted))
    return x0, x1
def array_bool_statistics(x: jax.Array) -> float:
    """Return the fraction (between 0 and 1) of ``True`` entries in boolean array ``x``.

    The result is a 0-d JAX array rather than a Python float. A non-boolean
    input trips an assertion.
    """
    assert x.dtype == jnp.bool_
    num_true = jnp.sum(x)
    return num_true / x.size
def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int] = 0):
    """Sample the trained model and measure how often samples leave p_data's support.

    Hyper-parameters are read from the module-level ``config`` (bound in the
    ``__main__`` guard), not from an argument.

    Args:
        key: PRNG key. It is split into independent keys for the conditioning
            points and for the ODE's initial noise.
        model: Velocity network passed to ``ode_trajectory``.
        save_plot: When not None, a scatter plot of the samples is saved as
            ``results/<run>/scatter_<save_plot>.png``. The default is 0, so a plot
            is saved unless None is passed explicitly.

    Returns:
        ``(stats_donut, stats_dent)`` as 0-d arrays:
        - ``stats_donut``: fraction of samples outside ``r1 < |x| < r2``.
        - ``stats_dent``: fraction inside the dent wedge ``|angle| < d / 2``.
    """
    num_samples = 32_768
    _, p_data_key, p_t_key = jax.random.split(key, 3)

    # Brownian-mode models take a start point as extra input; draw it from p_data
    # with its own dedicated key.
    cond = None
    if config.brownian_motion:
        cond = sample_p_data(
            p_data_key, num_samples=num_samples, r1=config.r1, r2=config.r2, d=config.d
        )

    samples = ode_trajectory(
        p_t_key,
        model=model,
        num_samples=num_samples,
        sample_steps=config.sample_steps,
        space_dimensions=config.space_dimensions,
        cond=cond,
    )
    stats_donut = array_bool_statistics(
        jnp.logical_not(check_donut_constraint(samples, r1=config.r1, r2=config.r2))
    )
    stats_dent = array_bool_statistics(
        jnp.logical_not(check_dent_constraint(samples, d=config.d))
    )

    if save_plot is not None:
        save_folder = (
            Path("results")
            / f"mlp{'_b' if config.brownian_motion else ''}_l{config.num_hidden_layers}_h{config.hidden_size}"
        )
        save_folder.mkdir(parents=True, exist_ok=True)
        fig = plt.figure()
        try:
            sns.scatterplot(x=samples[:, 0], y=samples[:, 1], size=0.1)
            fig.savefig(
                save_folder / f"scatter_{save_plot}.png", dpi=300, bbox_inches="tight"
            )
        finally:
            # Always release the figure, even if plotting/saving fails.
            plt.close(fig)

    return stats_donut, stats_dent
# --- Model definition -----------------------------------
class MLP(nnx.Module):
    """Time-conditioned MLP used as the velocity field ``f(x, t)``.

    A sinusoidal embedding of ``t`` (``fourier_features`` wide) is concatenated to
    ``x`` and fed through:

    - ``Linear(in_features + fourier_features -> hidden_size)``, SiLU;
    - ``num_hidden_layers`` x [``Linear(hidden_size -> hidden_size)``, SiLU];
    - ``Linear(hidden_size -> out_features)``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        fourier_features: int,
        num_hidden_layers: int,
        hidden_size: int,
        use_bias: bool,
        fourier_max_period: float,
        rngs: nnx.Rngs,
    ) -> None:
        self.fourier_dim = fourier_features
        self.fourier_max_period = fourier_max_period
        network = [
            nnx.Linear(
                in_features=in_features + fourier_features,
                out_features=hidden_size,
                use_bias=use_bias,
                rngs=rngs,
            ),
            nnx.silu,
        ]
        for _ in range(num_hidden_layers):
            network.append(
                nnx.Linear(
                    in_features=hidden_size,
                    out_features=hidden_size,
                    use_bias=use_bias,
                    rngs=rngs,
                )
            )
            network.append(nnx.silu)
        network.append(
            nnx.Linear(
                in_features=hidden_size,
                out_features=out_features,
                use_bias=use_bias,
                rngs=rngs,
            )
        )
        self.network = nnx.Sequential(*network)

    def time_embed(
        self, t: jax.Array, embed_dim: int, max_period: float = 10_000.0
    ) -> jax.Array:
        """Sinusoidal features of ``t``: ``[sin(t * f_k), cos(t * f_k)]``.

        The frequencies are ``f_k = max_period ** (-k / (embed_dim // 2))`` for
        ``k = 0 .. embed_dim // 2 - 1``.

        Args:
            t: Times. A trailing singleton axis is added unless already present.
            embed_dim: Output feature width; must be even.
            max_period: Sets the lowest frequency.

        Returns:
            Array whose last axis has size ``embed_dim``.
        """
        assert embed_dim % 2 == 0, "embed_dim must be even"
        t = jnp.asarray(t)
        # A 0-d t has no last axis; check ndim first instead of raising IndexError.
        if t.ndim == 0 or t.shape[-1] != 1:
            t = t[..., None]
        half_dim = embed_dim // 2
        freqs = jnp.exp(
            -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) / half_dim
        )
        args = t * freqs
        return jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)

    def __call__(self, x: jax.Array, t: jax.Array) -> jax.Array:
        """Evaluate the network on states ``x`` (features on the last axis) at times ``t``."""
        t = jnp.asarray(t)
        if t.ndim < x.ndim:
            # One time per sample (or one shared scalar): broadcast to x's batch
            # shape and add the feature axis here. Relying on time_embed's
            # `shape[-1] != 1` check alone fails for a batch of one, where t has
            # shape (1,) and would never get the extra axis.
            t = jnp.broadcast_to(t, x.shape[:-1])[..., None]
        t_encoded = self.time_embed(t, self.fourier_dim, self.fourier_max_period)
        x = jnp.concatenate((x, t_encoded), axis=-1)
        return self.network(x)
# --- Diffusion functions -----------------------------
def alpha(t: jax.Array) -> jax.Array:
    """Signal coefficient of the linear interpolation path: ``t`` clamped to ``[0, 1]``."""
    t_lo, t_hi = 0.0, 1.0
    return jnp.clip(t, t_lo, t_hi)


# d(alpha)/dt for a single scalar time.
alpha_scalar_grad = jax.grad(alpha)
# The same derivative applied element-wise over a 1-D batch of times, jit-compiled.
alpha_grad = jax.jit(jax.vmap(alpha_scalar_grad))
def beta(t: jax.Array) -> jax.Array:
    """Noise coefficient of the linear interpolation path: ``1 - clip(t, 0, 1)``."""
    clamped = jnp.clip(t, 0.0, 1.0)
    return 1.0 - clamped


# d(beta)/dt for a single scalar time.
beta_scalar_grad = jax.grad(beta)
# The same derivative applied element-wise over a 1-D batch of times, jit-compiled.
beta_grad = jax.jit(jax.vmap(beta_scalar_grad))
def ode_step(model: "MLP", x_t: jax.Array, t: jax.Array, h: float) -> jax.Array:
    """Take one explicit Euler step of ``dx/dt = model(x_t, t)`` with step size ``h``.

    ``x_t`` may carry conditioning columns after the state, e.g. the start point
    that ``ode_trajectory`` concatenates in Brownian mode. The state is taken to be
    the leading ``k`` columns, where ``k`` is the width of the model's output, and
    only those are advanced and returned.
    """
    velocity = model(x_t, t)
    # Derive the state width from the model output instead of assuming 2-D points.
    state_dim = velocity.shape[-1]
    return x_t[..., :state_dim] + h * velocity
@partial(nnx.jit, static_argnums=(2, 3, 4))
def ode_trajectory(
    key: jax.random.PRNGKey,
    model: MLP,
    num_samples: int,
    sample_steps: int,
    space_dimensions: int,
    cond: Optional[jax.Array] = None,
) -> jax.Array:
    """Generate samples by Euler-integrating ``model`` from t = 0 to t = 1.

    Starts from standard-normal noise of shape ``(num_samples, space_dimensions)``
    and takes ``sample_steps`` steps of size ``1 / sample_steps``.

    Args:
        key: PRNG key for the initial noise.
        model: Velocity network, called as ``model(x, t)`` via ``ode_step``.
        num_samples: Number of samples (static).
        sample_steps: Number of Euler steps (static).
        space_dimensions: Dimension of each sample (static).
        cond: Optional conditioning array, concatenated to the state on the last
            axis at every step (used by the Brownian-conditioned model).

    Returns:
        The state after ``sample_steps`` steps.
    """
    h = 1.0 / sample_steps
    t0 = jnp.zeros((num_samples,))
    x_init = jax.random.normal(key=key, shape=(num_samples, space_dimensions))

    def body(_, state):
        t, x = state
        x_cond = jnp.concat((x, cond), axis=-1) if cond is not None else x
        x = ode_step(model, x_cond, t, h)
        return (t + h, x)

    # The step count must match h = 1 / sample_steps so integration ends at t = 1,
    # hence it comes from the argument rather than any global configuration.
    _, x = jax.lax.fori_loop(0, sample_steps, body, (t0, x_init))
    return x
def sde_step(model: "MLP", x_t: jax.Array, t: jax.Array, sigma_t: jax.Array) -> jax.Array:
    """Placeholder for one stochastic (SDE) sampling step; not implemented.

    Raises:
        NotImplementedError: Always, so a caller cannot mistake the stub's result
            for a valid state.
    """
    raise NotImplementedError("sde_step is not implemented yet")
def sde_trajectory(model: "MLP") -> jax.Array:
    """Placeholder for SDE-based sampling; not implemented.

    Raises:
        NotImplementedError: Always, so a caller cannot mistake the stub's result
            for samples.
    """
    raise NotImplementedError("sde_trajectory is not implemented yet")
# --- Training ----------------------------------------
@nnx.jit
def train_step(model: MLP, optim: nnx.Optimizer, z: jax.Array, key: jax.random.PRNGKey):
    """Run one flow-matching optimisation step on a batch of data points ``z``.

    The step proceeds as follows:

    1. Sample ``t ~ U[0, 1)`` per row and ``eps ~ N(0, I)``.
    2. Form ``x_t = alpha(t) z + beta(t) eps``.
    3. Regress ``model(x_t, t)`` onto the path velocity
       ``alpha'(t) z + beta'(t) eps``, using a per-sample squared error averaged
       over the batch.
    4. Apply one update through ``optim`` and return the loss.
    """
    key_e, key_t = jax.random.split(key, 2)
    batch = z.shape[0]
    eps = jax.random.normal(key=key_e, shape=z.shape)
    t = jax.random.uniform(key=key_t, shape=[batch])

    a_t = alpha(t)[:, None]
    b_t = beta(t)[:, None]
    x = a_t * z + b_t * eps

    def loss_fn(net, z, t, eps):
        target = alpha_grad(t)[:, None] * z + beta_grad(t)[:, None] * eps
        residual = net(x, t) - target
        per_sample = jnp.sum(residual**2, axis=-1)
        return jnp.mean(per_sample)

    loss, grads = nnx.value_and_grad(loss_fn)(model, z, t, eps)
    optim.update(grads)
    return loss
@nnx.jit
def train_step_brownian(
    model: MLP,
    optim: nnx.Optimizer,
    z: tuple[jax.Array, jax.Array],
    key: jax.random.PRNGKey,
):
    """Run one flow-matching step for the Brownian-conditioned model.

    ``z = (z0, z1)``, as produced by ``brownian_motion_step``. The noisy point
    ``x_t = alpha(t) z1 + beta(t) eps`` is concatenated with ``z0`` as model input.
    The model is regressed onto the path velocity ``alpha'(t) z1 + beta'(t) eps``,
    the same target ``train_step`` uses. Returns the batch-mean squared error.
    """
    z0, z1 = z
    assert z0.shape == z1.shape
    key_e, key_t = jax.random.split(key)
    eps = jax.random.normal(key=key_e, shape=z1.shape)
    t = jax.random.uniform(key=key_t, shape=[z1.shape[0]])
    x = alpha(t)[:, None] * z1 + beta(t)[:, None] * eps
    x_brownian = jnp.concat((x, z0), axis=-1)
    # d/dt of x_t: the data term alpha'(t) * z1 must be added, not multiplied in.
    target = alpha_grad(t)[:, None] * z1 + beta_grad(t)[:, None] * eps

    def loss_fn(model):
        residual = model(x_brownian, t) - target
        return jnp.mean(jnp.sum(residual**2, axis=-1))

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optim.update(grads)
    return loss
def setup_model_and_optimizer(config: Config, rngs: nnx.Rngs) -> (MLP, nnx.Optimizer):
    """Build the velocity MLP and its optimiser from ``config``.

    In Brownian mode the network input is the state concatenated with the
    conditioning start point, so its input width is doubled. The optimiser clips
    gradients to a global norm of 1.0 and then applies AdamW with learning rate 3e-4.
    """
    input_width = config.space_dimensions * (2 if config.brownian_motion else 1)
    model = MLP(
        in_features=input_width,
        out_features=config.space_dimensions,
        fourier_features=config.fourier_dim,
        num_hidden_layers=config.num_hidden_layers,
        hidden_size=config.hidden_size,
        use_bias=config.mlp_bias,
        fourier_max_period=config.fourier_max_period,
        rngs=rngs,
    )
    transform = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(learning_rate=3e-4),
    )
    optim = nnx.Optimizer(model, tx=transform)
    return model, optim
def update_experiments_log(i, stats_donut, stats_dent, config: Config) -> None:
    """Append one row for run ``config`` at step ``i`` to ``results/experiments.csv``.

    Only the most recent entries of ``stats_donut`` / ``stats_dent`` are written.
    Reads and writes of the CSV happen while holding ``results/experiments.lock``.

    Args:
        i: Training step the statistics were measured at.
        stats_donut: History of donut statistics; the last element is logged.
        stats_dent: History of dent statistics; the last element is logged.
        config: Run configuration whose hyper-parameters are logged.
    """
    results_dir = "results"
    # NOTE(review): FileLock presumably cannot create its lock file in a missing
    # directory, and results/ is otherwise only created when a plot is saved.
    os.makedirs(results_dir, exist_ok=True)
    file = os.path.join(results_dir, "experiments.csv")
    new_row = pd.DataFrame(
        {
            "num_hidden_layers": [config.num_hidden_layers],
            "hidden_size": [config.hidden_size],
            "mlp_bias": [config.mlp_bias],
            "fourier_dim": [config.fourier_dim],
            "r1": [config.r1],
            "r2": [config.r2],
            "d": [config.d],
            "sample_steps": [config.sample_steps],
            "p_donut": [stats_donut[-1]],
            "p_dent": [stats_dent[-1]],
            "step": [i],
            "brownian": [config.brownian_motion],
        }
    )
    try:
        # NOTE(review): with no timeout given, FileLock presumably waits
        # indefinitely, so the Timeout branch may be unreachable — confirm intent.
        with FileLock(os.path.join(results_dir, "experiments.lock")):
            if os.path.exists(file):
                df = pd.concat([pd.read_csv(file), new_row], ignore_index=True)
            else:
                # First row: write it directly instead of concatenating onto an
                # empty frame, which newer pandas versions warn about.
                df = new_row
            df.to_csv(file, index=False)
    except Timeout:
        print("Timeout!!!")
def plot_p_data(key: jax.random.PRNGKey, config: Config):
    """Save a diagnostic plot of the training data under ``results/``.

    - Without Brownian motion: a scatter of ``config.batch_size`` samples from
      ``sample_p_data`` is written to ``results/p_data.png``.
    - With Brownian motion: start points x0, end points x1 and a segment joining
      each pair (from ``brownian_motion_step``) are written to
      ``results/brownian.png``.
    """
    # Both branches write into results/, so make sure it exists up front.
    Path("results").mkdir(parents=True, exist_ok=True)

    if not config.brownian_motion:
        points = sample_p_data(key, config.batch_size, config.r1, config.r2, config.d)
        plt.figure()
        sns.scatterplot(x=points[:, 0], y=points[:, 1], size=0.1)
        plt.savefig("results/p_data.png", dpi=300, bbox_inches="tight")
        plt.close()
        return

    x0, x1 = brownian_motion_step(
        key,
        config.batch_size,
        config.r1,
        config.r2,
        config.d,
        config.brownian_sigma,
    )
    # matplotlib works on host NumPy arrays; each has shape (config.batch_size, 2).
    x0_np = np.array(x0)
    x1_np = np.array(x1)

    plt.figure(figsize=(8, 8))
    # One thin segment per sample, from its start point to its end point.
    for start, end in zip(x0_np, x1_np):
        plt.plot([start[0], end[0]], [start[1], end[1]], linewidth=0.5, alpha=0.5)
    plt.scatter(x0_np[:, 0], x0_np[:, 1], s=10, label="x₀")
    plt.scatter(x1_np[:, 0], x1_np[:, 1], s=10, label="x₁")
    plt.legend(loc="upper right")
    plt.axis("equal")
    plt.title("Brownian step in the dented annulus")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.tight_layout()
    plt.savefig("results/brownian.png")
    plt.close()
def _normalise_by_area(value: float, area: float) -> float:
    """Divide ``value`` by ``area``; return NaN when the area is zero (e.g. ``d == 0``)."""
    return value / area if area != 0 else float("nan")


def _plot_constraint_stats(stats_donut, stats_dent, config: Config) -> None:
    """Plot the recorded donut/dent statistics side by side and save them for this run."""
    sns.set_theme(style="whitegrid", palette="pastel")
    fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
    panels = (
        (axes[0], stats_donut, "Donut Statistics"),
        (axes[1], stats_dent, "Dent Statistics"),
    )
    for ax, stats, title in panels:
        sns.lineplot(data=stats, ax=ax)
        ax.set_title(title, fontsize=14, fontweight="bold")
        ax.set_xlabel("Index")
        # The plotted values are violation fractions divided by a region area
        # (see main), not the share of samples satisfying the constraint.
        ax.set_ylabel("Violating fraction / area")
        ax.grid(True, linestyle="--", alpha=0.7)
    fig.suptitle("Comparison of Donut vs. Dent Trends", fontsize=16, fontweight="bold")
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    save_folder = (
        Path("results")
        / f"mlp{'_b' if config.brownian_motion else ''}_l{config.num_hidden_layers}_h{config.hidden_size}"
    )
    save_folder.mkdir(parents=True, exist_ok=True)
    fig.savefig(os.path.join(save_folder, "constraint_stats.png"))
    plt.close(fig)


def main(config: Config):
    """Train the flow model described by ``config`` and record constraint statistics.

    With ``show_p_data`` set, only the p_data diagnostic plot is written. Otherwise:

    - Train for ``num_steps`` steps.
    - Every ``evaluate_constraints_every`` steps (0 = never, step 0 excluded),
      sample the model, append the area-normalised violation rates and append
      them to the experiments CSV.
    - Finally, plot the collected statistics.

    NOTE(review): evaluate_constraints reads the module-level ``config``, not this
    argument; the two coincide only when run through the ``__main__`` guard.
    """
    rngs = nnx.Rngs(config.seed)
    if config.show_p_data:
        plot_p_data(rngs.params(), config)
        # Return instead of the interactive-only `exit()`; when run as a script
        # this ends the program the same way.
        return

    model, optim = setup_model_and_optimizer(config, rngs)
    step_fn = train_step_brownian if config.brownian_motion else train_step
    step_fn = nnx.cached_partial(step_fn, model, optim)
    base_sampler = (
        partial(brownian_motion_step, sigma=config.brownian_sigma)
        if config.brownian_motion
        else sample_p_data
    )
    sampler = partial(
        base_sampler,
        num_samples=config.batch_size,
        r1=config.r1,
        r2=config.r2,
        d=config.d,
    )

    # Area of the dented annulus (p_data's support) and of the removed wedge.
    ring_area = config.r2**2 - config.r1**2
    a_donut = (jnp.pi - 0.5 * config.d) * ring_area
    a_dent = 0.5 * config.d * ring_area

    stats_donut = []
    stats_dent = []
    for i in tqdm(range(config.num_steps)):
        z = sampler(rngs.params())
        step_fn(z, rngs.params())
        should_evaluate = (
            config.evaluate_constraints_every != 0
            and i > 0
            and i % config.evaluate_constraints_every == 0
        )
        if not should_evaluate:
            continue
        stat_donut, stat_dent = evaluate_constraints(
            rngs.params(), model, save_plot=i if config.save_scatterplot else None
        )
        stats_donut.append(_normalise_by_area(stat_donut.item(), a_donut))
        stats_dent.append(_normalise_by_area(stat_dent.item(), a_dent))
        update_experiments_log(
            i=i, stats_donut=stats_donut, stats_dent=stats_dent, config=config
        )

    _plot_constraint_stats(stats_donut, stats_dent, config)
if __name__ == "__main__":
    # Parse the CLI into a Config and bind it at module scope: evaluate_constraints
    # reads this global `config` directly instead of taking it as an argument, so
    # it must exist before main() runs.
    config: Config = tyro.cli(Config)
    main(config)