You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
595 lines
18 KiB
595 lines
18 KiB
from typing import Optional
|
|
import os
|
|
from pathlib import Path
|
|
import tyro
|
|
from functools import partial
|
|
from dataclasses import dataclass
|
|
from tqdm import tqdm
|
|
import seaborn as sns
|
|
import matplotlib.pyplot as plt
|
|
import pandas as pd
|
|
from filelock import FileLock, Timeout
|
|
|
|
import numpy as np
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import flax.nnx as nnx
|
|
import optax
|
|
|
|
# Ask the XLA client to use only a tenth of the accelerator memory.
# NOTE(review): this is assigned after `import jax`; presumably it still takes
# effect because the backend is initialised lazily on first use — confirm.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.1"
|
|
|
|
|
|
# Immutable run configuration; the entry point builds it with `tyro.cli(Config)`.
# NOTE(review): the string literal under each field presumably doubles as that
# option's CLI help text — confirm against the tyro docs before rewording any.
@dataclass(frozen=True)
class Config:
    """Flow/DDM training of a simple distribution."""

    # --- Problem definition ---
    space_dimensions: int = 2
    """The dimensionality of the distribution's space."""

    brownian_motion: bool = False
    """Enables brownian motion."""

    brownian_sigma: float = 0.1
    """The size of the brownian motion."""

    # --- Network architecture ---
    num_hidden_layers: int = 4
    """Number of hidden layers in the MLP."""

    hidden_size: int = 64
    """The size of the hidden layers of the MLP."""

    mlp_bias: bool = True
    """Enable the bias on every layer of the MLP."""

    # Must be even: MLP.time_embed asserts `embed_dim % 2 == 0`.
    fourier_dim: int = 6
    """Fourier dimensions. Will be concatenated to the input of the MLP."""

    fourier_max_period: float = 10_000.0
    """Range of features of the Fourier features."""

    # --- Optimisation ---
    num_steps: int = 100_000
    """How many steps of gradient descent to perform."""

    batch_size: int = 256
    """How many samples per mini-batch."""

    # --- Geometry of the target distribution ---
    r1: float = 0.3
    """Inner radius of the donut for p_data."""

    r2: float = 0.8
    """Outer radius of the donut for p_data."""

    d: float = 0.3
    """Size (radians) of the dent in the donut."""

    # --- Sampling, evaluation and output ---
    sample_steps: int = 100
    """The number of steps taken during sampling."""

    # 0 disables periodic evaluation (main() checks `!= 0` before evaluating).
    evaluate_constraints_every: int = 0
    """Evaluate the constraints after given steps."""

    save_scatterplot: bool = False
    """For every step the constraints are evaluated, save a scatterplot."""

    # When set, main() only plots p_data and stops without training.
    show_p_data: bool = False
    """If set, the script will generate a scatterplot sampled from the p_data process."""

    seed: int = 42
    """The seed used for randomness."""
|
|
|
|
|
|
# --- Data generation process ----------------------------
|
|
@partial(jax.jit, static_argnums=(1,))
def sample_p_data(
    key: jax.random.PRNGKey, num_samples: int, r1: float, r2: float, d: float = 0.0
) -> jax.Array:
    """Draw points from an annulus with an angular wedge ("dent") removed.

    Args:
        key: PRNG key; split once into a radius key and an angle key.
        num_samples: Number of points to draw (static under jit).
        r1: Inner radius of the annulus.
        r2: Outer radius of the annulus.
        d: Angular width in radians of the excluded wedge, centred on angle 0.

    Returns:
        Array of shape ``(num_samples, 2)`` with cartesian ``(x, y)`` points.
    """
    radius_key, angle_key = jax.random.split(key)

    # Inverse-CDF sampling of the radius so that points are uniform in area:
    #   CDF(r) = (r^2 - r1^2) / (r2^2 - r1^2)  =>  r = sqrt(u (r2^2 - r1^2) + r1^2)
    unit = jax.random.uniform(radius_key, (num_samples,), minval=0.0, maxval=1.0)
    radius = jnp.sqrt(unit * (r2**2 - r1**2) + r1**2)

    # Angle is uniform on [d/2, 2*pi - d/2], i.e. everywhere except the wedge.
    half_gap = d / 2.0
    angle = jax.random.uniform(
        angle_key, (num_samples,), minval=half_gap, maxval=2 * jnp.pi - half_gap
    )

    # Polar -> cartesian.
    return jnp.stack((radius * jnp.cos(angle), radius * jnp.sin(angle)), axis=-1)
|
|
|
|
|
|
def check_donut_constraint(x: jax.Array, r1: float, r2: float) -> jax.Array:
    """Flag which 2-D points lie strictly between the two radii.

    Args:
        x: Array of shape ``(n, 2)`` holding cartesian points.
        r1: Inner radius (exclusive).
        r2: Outer radius (exclusive).

    Returns:
        Boolean array of shape ``(n,)``; True where ``r1 < |x| < r2``.
    """
    assert len(x.shape) == 2
    assert x.shape[1] == 2

    # Compare squared quantities so no square root is needed.
    xs, ys = x[:, 0], x[:, 1]
    dist_sq = jnp.square(xs) + jnp.square(ys)
    outside_hole = dist_sq > (r1**2)
    inside_rim = dist_sq < (r2**2)
    return jnp.logical_and(outside_hole, inside_rim)
|
|
|
|
|
|
def check_dent_constraint(x: jax.Array, d: float) -> jax.Array:
    """Flag which 2-D points lie outside the excluded angular wedge.

    Args:
        x: Array of shape ``(n, 2)`` holding cartesian points.
        d: Angular width in radians of the wedge centred on angle 0.

    Returns:
        Boolean array of shape ``(n,)``; False only where the polar angle is
        strictly inside ``(-d/2, d/2)``.
    """
    assert len(x.shape) == 2
    assert x.shape[1] == 2

    # Polar angle in (-pi, pi], so the wedge is symmetric around zero.
    angle = jnp.arctan2(x[:, 1], x[:, 0])
    half_width = d / 2
    in_wedge = jnp.logical_and(angle < half_width, angle > -half_width)
    return jnp.logical_not(in_wedge)
|
|
|
|
|
|
def brownian_motion(key: jax.random.PRNGKey, x0: jax.Array, sigma: float) -> jax.Array:
    """Perturb ``x0`` with isotropic Gaussian noise of standard deviation ``sigma``.

    Args:
        key: PRNG key used to draw the noise.
        x0: Starting points; the noise has the same shape.
        sigma: Scale applied to the standard-normal draw.

    Returns:
        ``x0`` plus the scaled noise, same shape as ``x0``.
    """
    displacement = sigma * jax.random.normal(key=key, shape=x0.shape)
    return x0 + displacement
|
|
|
|
|
|
# NOTE(review): only `num_samples` (index 1) is marked static; `num_check_pts`
# feeds `jnp.linspace`, so passing it explicitly would presumably hand linspace
# a traced value — confirm before overriding the default.
@partial(jax.jit, static_argnums=(1,))
def brownian_motion_step(
    key: jax.random.PRNGKey,
    num_samples: int,
    r1: float,
    r2: float,
    d: float = 0.0,
    sigma: float = 0.1,
    num_check_pts: int = 8,
) -> tuple[jax.Array, jax.Array]:
    """Sample start points and a constraint-respecting noisy move for each.

    Start points ``x0`` come from ``sample_p_data``. For every sample a proposal
    ``x0 + sigma * noise`` is drawn repeatedly until the proposal itself and
    ``num_check_pts`` evenly spaced interior points of the straight segment
    from start to proposal all satisfy both the donut and the dent constraint.

    Args:
        key: PRNG key; split for the initial sample and for every proposal round.
        num_samples: Number of (start, end) pairs to produce (static under jit).
        r1: Inner radius of the annulus.
        r2: Outer radius of the annulus.
        d: Angular width in radians of the excluded wedge.
        sigma: Standard deviation of the proposal noise.
        num_check_pts: Number of interior points tested along each segment.

    Returns:
        ``(x0, x1)``: the start points and the accepted end points.
    """
    key, key_init = jax.random.split(key)
    x0 = sample_p_data(key_init, num_samples, r1, r2, d)

    def points_satisfy(x: jax.Array) -> jax.Array:
        # True where a point is inside the annulus and outside the wedge.
        return jnp.logical_and(
            check_donut_constraint(x, r1, r2), check_dent_constraint(x, d)
        )

    # Interpolation fractions strictly between 0 and 1 (both endpoints dropped).
    ts = jnp.linspace(0.0, 1.0, num_check_pts + 2)[1:-1]

    def cond_fn(state):
        # Keep looping while at least one sample has no accepted proposal.
        _, _, cs = state
        return jnp.any(~cs)

    def body_fn(state):
        key, x, cs = state
        key, _key = jax.random.split(key)
        # A fresh proposal is drawn for every sample, always around the start x0.
        x_prop = brownian_motion(_key, x0, sigma)
        ok_end = points_satisfy(x_prop)

        # For samples not yet accepted x still equals x0, so `diff` is the
        # start-to-proposal displacement; accepted rows are masked out below.
        diff = x_prop - x
        pts = x[:, None, :] + diff[:, None, :] * ts[None, :, None]
        ok_interior = jnp.all(
            points_satisfy(pts.reshape(-1, pts.shape[-1])).reshape(num_samples, -1),
            axis=1,
        )
        satisfied = ok_end & ok_interior
        # Only overwrite rows that were still pending and have just been accepted.
        x = jnp.where(jnp.logical_and(~cs, satisfied)[:, None], x_prop, x)
        cs = jnp.logical_or(cs, satisfied)
        return key, x, cs

    init_cs = jnp.zeros((num_samples,), dtype=bool)
    init_state = (key, x0, init_cs)
    # NOTE: there is no iteration cap; the loop ends only once every sample
    # has an accepted proposal.
    _, x1, _ = jax.lax.while_loop(cond_fn, body_fn, init_state)

    return x0, x1
|
|
|
|
|
|
def array_bool_statistics(x: jax.Array) -> float:
    """Return the fraction (0..1) of True entries in a boolean array."""
    assert x.dtype == jnp.bool_

    num_true = jnp.count_nonzero(x)
    return num_true / x.size
|
|
|
|
|
|
def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int] = 0):
    """Sample from the model and measure how often the constraints are violated.

    Reads the module-level ``config`` for geometry, sampling and naming options.

    Args:
        key: PRNG key for the evaluation samples.
        model: Model handed to ``ode_trajectory``.
        save_plot: When not None, a scatter plot of the samples is written to
            ``scatter_<save_plot>.png`` in the run's results folder.

    Returns:
        ``(stats_donut, stats_dent)``: fraction of samples violating the donut
        constraint and fraction violating the dent constraint.
    """
    num_eval_samples = 32_768
    # The middle key of this split is not consumed below.
    key, _, trajectory_key = jax.random.split(key, 3)

    start_points = None
    if config.brownian_motion:
        # Conditioning points for the brownian variant come from p_data.
        key, _ = jax.random.split(key)
        start_points = sample_p_data(
            key, num_samples=num_eval_samples, r1=config.r1, r2=config.r2, d=config.d
        )

    samples = ode_trajectory(
        trajectory_key,
        model=model,
        num_samples=num_eval_samples,
        sample_steps=config.sample_steps,
        space_dimensions=config.space_dimensions,
        cond=start_points,
    )

    outside_donut = jnp.logical_not(
        check_donut_constraint(samples, r1=config.r1, r2=config.r2)
    )
    inside_dent = jnp.logical_not(check_dent_constraint(samples, d=config.d))
    stats_donut = array_bool_statistics(outside_donut)
    stats_dent = array_bool_statistics(inside_dent)

    if save_plot is None:
        return stats_donut, stats_dent

    plt.figure()
    sns.scatterplot(x=samples[:, 0], y=samples[:, 1], size=0.1)
    save_folder = (
        Path("results")
        / f"mlp{'_b' if config.brownian_motion else ''}_l{config.num_hidden_layers}_h{config.hidden_size}"
    )
    save_folder.mkdir(parents=True, exist_ok=True)
    plt.savefig(
        f"{save_folder}/scatter_{save_plot}.png", dpi=300, bbox_inches="tight"
    )
    plt.close()

    return stats_donut, stats_dent
|
|
|
|
|
|
# --- Model definition -----------------------------------
|
|
class MLP(nnx.Module):
    """Fully connected network over ``[x, fourier(t)]`` with SiLU activations.

    The stack is: one input layer, ``num_hidden_layers`` hidden layers (each
    followed by SiLU) and one linear output layer.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        fourier_features: int,
        num_hidden_layers: int,
        hidden_size: int,
        use_bias: bool,
        fourier_max_period: float,
        rngs: nnx.Rngs,
    ) -> None:
        """Build the layer stack.

        Args:
            in_features: Width of ``x`` (without the time features).
            out_features: Width of the network output.
            fourier_features: Number of time features appended to ``x``;
                must be even (checked in ``time_embed``).
            num_hidden_layers: Number of hidden ``Linear`` + SiLU pairs.
            hidden_size: Width of every hidden layer.
            use_bias: Whether each ``Linear`` layer has a bias term.
            fourier_max_period: ``max_period`` forwarded to ``time_embed``.
            rngs: Random streams used to initialise the layers.
        """
        self.fourier_dim = fourier_features
        self.fourier_max_period = fourier_max_period

        # Input layer consumes the state concatenated with the time features.
        network = [
            nnx.Linear(
                in_features=in_features + fourier_features,
                out_features=hidden_size,
                use_bias=use_bias,
                rngs=rngs,
            ),
            nnx.silu,
        ]
        for _ in range(num_hidden_layers):
            network.append(
                nnx.Linear(
                    in_features=hidden_size,
                    out_features=hidden_size,
                    use_bias=use_bias,
                    rngs=rngs,
                )
            )
            network.append(nnx.silu)

        # Output projection has no activation.
        network.append(
            nnx.Linear(
                in_features=hidden_size,
                out_features=out_features,
                use_bias=use_bias,
                rngs=rngs,
            )
        )

        self.network = nnx.Sequential(*network)

    def time_embed(
        self, t: jax.Array, embed_dim: int, max_period: float = 10_000.0
    ) -> jax.Array:
        """Sinusoidal embedding of ``t``.

        Args:
            t: Times; a trailing axis of size 1 is added unless already present.
            embed_dim: Number of output features; must be even.
            max_period: Controls the geometric spacing of the frequencies.

        Returns:
            Array whose last axis has ``embed_dim`` entries: sines first, then
            cosines.
        """
        assert embed_dim % 2 == 0, "embed_dim must be even"

        t = jnp.asarray(t)
        # A 0-d time has no last axis to inspect, so test ndim first.
        if t.ndim == 0 or t.shape[-1] != 1:
            t = t[..., None]

        half_dim = embed_dim // 2
        freqs = jnp.exp(
            -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) / half_dim
        )
        args = t * freqs
        time_features = jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)
        return time_features

    def __call__(self, x: jax.Array, t: jax.Array) -> jax.Array:
        """Evaluate the network at state ``x`` and time ``t``.

        Args:
            x: Input of shape ``(..., in_features)``.
            t: Times, one per leading index of ``x`` (a scalar is broadcast).

        Returns:
            Array of shape ``(..., out_features)``.
        """
        t = jnp.asarray(t)
        # Add the feature axis here, where the rank of x disambiguates it:
        # a batch of one, t.shape == (1,), would otherwise look to time_embed
        # like an already-expanded time and the embedding would lose its batch
        # axis.
        if t.ndim < x.ndim:
            t = t[..., None]

        t_encoded = self.time_embed(t, self.fourier_dim, self.fourier_max_period)
        # Broadcast so a single shared time can be paired with a batched x.
        t_encoded = jnp.broadcast_to(t_encoded, x.shape[:-1] + t_encoded.shape[-1:])
        x = jnp.concatenate((x, t_encoded), axis=-1)
        return self.network(x)
|
|
|
|
|
|
# --- Diffusion functions -----------------------------
|
|
def alpha(t: jax.Array) -> jax.Array:
    """Coefficient applied to the data term: ``t`` clamped to ``[0, 1]``."""
    lower, upper = 0.0, 1.0
    return jnp.clip(t, lower, upper)


# Derivative of alpha with respect to a scalar time ...
alpha_scalar_grad = jax.grad(alpha)
# ... mapped over a batch of times and compiled.
alpha_grad = jax.jit(jax.vmap(alpha_scalar_grad))
|
|
|
|
|
|
def beta(t: jax.Array) -> jax.Array:
    """Coefficient applied to the noise term: ``1 - t`` with ``t`` clamped to ``[0, 1]``."""
    clamped = jnp.clip(t, 0.0, 1.0)
    return 1.0 - clamped


# Derivative of beta with respect to a scalar time ...
beta_scalar_grad = jax.grad(beta)
# ... mapped over a batch of times and compiled.
beta_grad = jax.jit(jax.vmap(beta_scalar_grad))
|
|
|
|
|
|
def ode_step(model: MLP, x_t: jax.Array, t: jax.Array, h: float) -> jax.Array:
    """Advance the state by one explicit Euler step of size ``h``.

    Args:
        model: Callable returning the velocity for ``(x_t, t)``.
        x_t: Model input; its leading features are the current state and any
            remaining features are conditioning that is passed through to the
            model but not advanced.
        t: Current time for every sample.
        h: Step size.

    Returns:
        The updated state, with the same trailing width as the model output.
    """
    velocity = model(x_t, t)
    # Take the state width from the model output rather than hard-coding 2,
    # so the step also works when the space is not two-dimensional.
    state_dim = velocity.shape[-1]
    return x_t[..., :state_dim] + h * velocity
|
|
|
|
|
|
@partial(nnx.jit, static_argnums=(2, 3, 4))
def ode_trajectory(
    key: jax.random.PRNGKey,
    model: MLP,
    num_samples: int,
    sample_steps: int,
    space_dimensions: int,
    cond: Optional[jax.Array] = None,
) -> jax.Array:
    """Integrate the learned velocity field from noise at t=0 to t=1.

    Args:
        key: PRNG key for the initial standard-normal samples.
        model: Velocity model evaluated at every step.
        num_samples: Number of trajectories (static under jit).
        sample_steps: Number of Euler steps; the step size is its reciprocal
            (static under jit).
        space_dimensions: Width of the state (static under jit).
        cond: Optional conditioning, concatenated to the state along the last
            axis before every model call.

    Returns:
        Final states of shape ``(num_samples, space_dimensions)``.
    """
    t = jnp.zeros((num_samples,))
    h = 1.0 / sample_steps
    x = jax.random.normal(key=key, shape=(num_samples, space_dimensions))

    def body(i, state):
        t, x = state
        x_cond = jnp.concatenate((x, cond), axis=-1) if cond is not None else x
        x = ode_step(model, x_cond, t, h)
        return (t + h, x)

    # Use the `sample_steps` argument for the loop bound: the step size above
    # is derived from it, so reading a global here could integrate over the
    # wrong time span.
    _, x = jax.lax.fori_loop(0, sample_steps, body, (t, x))
    return x
|
|
|
|
|
|
def sde_step(model: MLP, x_t: jax.Array, t: jax.Array, sigma_t: jax.Array) -> jax.Array:
    """Placeholder for a single SDE sampling step.

    TODO: not implemented — the body is empty, so this currently returns None
    despite the ``jax.Array`` return annotation.
    """
    pass
|
|
|
|
|
|
def sde_trajectory(model: MLP) -> jax.Array:
    """Placeholder for SDE-based sampling.

    TODO: not implemented — the body is empty, so this currently returns None
    despite the ``jax.Array`` return annotation.
    """
    pass
|
|
|
|
|
|
# --- Training ----------------------------------------
|
|
|
|
|
|
@nnx.jit
def train_step(model: MLP, optim: nnx.Optimizer, z: jax.Array, key: jax.random.PRNGKey):
    """Run one optimisation step of the velocity-matching objective.

    A noisy point ``x = alpha(t) z + beta(t) eps`` is built for a random time
    per sample, and the model is regressed onto the time derivative of that
    path, ``alpha'(t) z + beta'(t) eps``.

    Args:
        model: Network being trained.
        optim: Optimizer whose ``update`` is called with the gradients.
        z: Batch of data samples.
        key: PRNG key for the noise and the times.

    Returns:
        The scalar loss (mean over the batch of the summed squared error).
    """
    noise_key, time_key = jax.random.split(key, 2)
    eps = jax.random.normal(key=noise_key, shape=z.shape)
    t = jax.random.uniform(key=time_key, shape=[z.shape[0]])

    # Point on the interpolation path and the velocity of that path.
    x = alpha(t)[:, None] * z + beta(t)[:, None] * eps
    target = alpha_grad(t)[:, None] * z + beta_grad(t)[:, None] * eps

    def loss_fn(net):
        residual = net(x, t) - target
        per_sample = jnp.sum(residual**2, axis=-1)
        return jnp.mean(per_sample)

    loss, grads = nnx.value_and_grad(loss_fn)(model)

    optim.update(grads)
    return loss
|
|
|
|
|
|
@nnx.jit
def train_step_brownian(
    model: MLP,
    optim: nnx.Optimizer,
    z: tuple[jax.Array, jax.Array],
    key: jax.random.PRNGKey,
):
    """Run one optimisation step for the brownian (conditioned) variant.

    The model sees the noisy point ``x = alpha(t) z1 + beta(t) eps``
    concatenated with the conditioning point ``z0`` and is regressed onto the
    velocity of that path, ``alpha'(t) z1 + beta'(t) eps``.

    Args:
        model: Network being trained.
        optim: Optimizer whose ``update`` is called with the gradients.
        z: Pair ``(z0, z1)`` of equally shaped arrays: conditioning points and
            target points.
        key: PRNG key for the noise and the times.

    Returns:
        The scalar loss (mean over the batch of the summed squared error).
    """
    z0, z1 = z
    assert z0.shape == z1.shape

    key_e, key_t = jax.random.split(key)
    eps = jax.random.normal(key=key_e, shape=z1.shape)
    t = jax.random.uniform(key=key_t, shape=[z1.shape[0]])
    x = alpha(t)[:, None] * z1 + beta(t)[:, None] * eps

    # Time derivative of the path used to build `x`. The previous expression
    # multiplied alpha' by beta' and dropped z1 altogether, so the regression
    # target did not depend on the data; this matches train_step.
    target = alpha_grad(t)[:, None] * z1 + beta_grad(t)[:, None] * eps

    def loss_fn(model):
        x_brownian = jnp.concatenate((x, z0), axis=-1)
        loss = jnp.sum((model(x_brownian, t) - target) ** 2, axis=-1)
        return jnp.mean(loss)

    value_grad_fn = nnx.value_and_grad(loss_fn)
    loss, grads = value_grad_fn(model)

    optim.update(grads)
    return loss
|
|
|
|
|
|
def setup_model_and_optimizer(config: Config, rngs: nnx.Rngs) -> (MLP, nnx.Optimizer):
    """Create the MLP described by ``config`` together with its optimizer.

    Args:
        config: Run configuration supplying the architecture options.
        rngs: Random streams used to initialise the model.

    Returns:
        ``(model, optim)``; the optimizer chains global-norm gradient clipping
        at 1.0 with AdamW at a learning rate of 3e-4.
    """
    # The brownian variant feeds the state plus an equally sized conditioning
    # point, hence twice the input width.
    if config.brownian_motion:
        input_width = 2 * config.space_dimensions
    else:
        input_width = config.space_dimensions

    model = MLP(
        in_features=input_width,
        out_features=config.space_dimensions,
        num_hidden_layers=config.num_hidden_layers,
        hidden_size=config.hidden_size,
        use_bias=config.mlp_bias,
        fourier_features=config.fourier_dim,
        fourier_max_period=config.fourier_max_period,
        rngs=rngs,
    )

    transform = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(learning_rate=3e-4),
    )
    optim = nnx.Optimizer(model, tx=transform)

    return model, optim
|
|
|
|
|
|
def update_experiments_log(i, stats_donut, stats_dent, config: Config) -> None:
    """Append the latest evaluation result to ``results/experiments.csv``.

    The read-modify-write is wrapped in a file lock on
    ``results/experiments.lock``.

    Args:
        i: Training step the statistics belong to.
        stats_donut: History of donut statistics; only the last entry is logged.
        stats_dent: History of dent statistics; only the last entry is logged.
        config: Run configuration whose options are stored next to the stats.
    """
    results_dir = "results"
    # The lock file and the CSV both live in this directory; make sure it
    # exists instead of relying on an earlier plotting call to have created it.
    os.makedirs(results_dir, exist_ok=True)
    file = os.path.join(results_dir, "experiments.csv")

    # Key order defines the CSV column order of a freshly created log.
    new_row = pd.DataFrame(
        {
            "num_hidden_layers": [config.num_hidden_layers],
            "hidden_size": [config.hidden_size],
            "mlp_bias": [config.mlp_bias],
            "fourier_dim": [config.fourier_dim],
            "r1": [config.r1],
            "r2": [config.r2],
            "d": [config.d],
            "sample_steps": [config.sample_steps],
            "p_donut": [stats_donut[-1]],
            "p_dent": [stats_dent[-1]],
            "step": [i],
            "brownian": [config.brownian_motion],
        }
    )

    try:
        with FileLock(os.path.join(results_dir, "experiments.lock")):
            if os.path.exists(file):
                df = pd.concat([pd.read_csv(file), new_row], ignore_index=True)
            else:
                # First entry: write the row directly rather than concatenating
                # it onto an empty frame.
                df = new_row
            df.to_csv(file, index=False)
    except Timeout:
        # NOTE(review): no timeout is passed to FileLock, so this branch is
        # presumably never taken with the default blocking behaviour — confirm.
        print("Timeout!!!")
|
|
|
|
|
|
def plot_p_data(key: jax.random.PRNGKey, config: Config):
    """Save a diagnostic plot of the data-generating process.

    Without brownian motion, one batch from ``sample_p_data`` is written to
    ``results/p_data.png``. With brownian motion, start/end pairs from
    ``brownian_motion_step`` are drawn as connected points and written to
    ``results/brownian.png``.

    Args:
        key: PRNG key used for sampling.
        config: Run configuration (batch size, geometry, brownian options).
    """
    # Both branches write into results/, so create it once up front; the
    # brownian branch previously assumed the directory already existed.
    Path("results").mkdir(parents=True, exist_ok=True)

    if not config.brownian_motion:
        points = sample_p_data(key, config.batch_size, config.r1, config.r2, config.d)
        plt.figure()
        sns.scatterplot(x=points[:, 0], y=points[:, 1], size=0.1)
        plt.savefig("results/p_data.png", dpi=300, bbox_inches="tight")
        plt.close()
        return

    x0, x1 = brownian_motion_step(
        key,
        config.batch_size,
        config.r1,
        config.r2,
        config.d,
        config.brownian_sigma,
    )

    # Convert to NumPy for matplotlib; each array has config.batch_size rows.
    x0_np = np.array(x0)
    x1_np = np.array(x1)

    plt.figure(figsize=(8, 8))

    # One thin segment per sample, from its start point to its end point.
    for start, end in zip(x0_np, x1_np):
        plt.plot([start[0], end[0]], [start[1], end[1]], linewidth=0.5, alpha=0.5)

    # Start points and end points as two labelled scatter series.
    plt.scatter(x0_np[:, 0], x0_np[:, 1], s=10, label="x₀")
    plt.scatter(x1_np[:, 0], x1_np[:, 1], s=10, label="x₁")

    plt.legend(loc="upper right")
    plt.axis("equal")
    plt.title("Brownian step in the dented annulus")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.tight_layout()
    plt.savefig("results/brownian.png")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()
|
|
|
|
|
|
def main(config: Config):
    """Train the model and, when enabled, track constraint statistics.

    With ``config.show_p_data`` set, only the data-process plot is produced.
    Otherwise the model is trained for ``config.num_steps`` steps; every
    ``config.evaluate_constraints_every`` steps (if non-zero) the constraints
    are evaluated, logged to the experiments CSV, and finally plotted to
    ``constraint_stats.png`` in the run's results folder.

    Args:
        config: Run configuration.
    """
    rngs = nnx.Rngs(config.seed)

    if config.show_p_data:
        plot_p_data(rngs.params(), config)
        # Return instead of calling exit(): `exit` is an interactive helper
        # from the site module, and returning keeps main() usable as a
        # library function.
        return

    model, optim = setup_model_and_optimizer(config, rngs)

    step_fn = train_step_brownian if config.brownian_motion else train_step
    step_fn = nnx.cached_partial(step_fn, model, optim)
    sampler = partial(
        (
            partial(brownian_motion_step, sigma=config.brownian_sigma)
            if config.brownian_motion
            else sample_p_data
        ),
        num_samples=config.batch_size,
        r1=config.r1,
        r2=config.r2,
        d=config.d,
    )

    # Areas of the dented annulus and of the removed wedge, used to normalise
    # the raw violation fractions.
    a_donut = (jnp.pi - 0.5 * config.d) * (config.r2**2 - config.r1**2)
    a_dent = 0.5 * config.d * (config.r2**2 - config.r1**2)

    def normalise(value: float, area: float) -> float:
        # A zero area (e.g. d == 0, no dent) would otherwise raise
        # ZeroDivisionError in the middle of training; record NaN instead.
        return value / area if area else float("nan")

    stats_donut = []
    stats_dent = []
    for i in tqdm(range(config.num_steps)):
        z = sampler(rngs.params())
        _ = step_fn(z, rngs.params())

        if (
            config.evaluate_constraints_every != 0
            and i > 0
            and i % config.evaluate_constraints_every == 0
        ):
            stat_donut, stat_dent = evaluate_constraints(
                rngs.params(), model, save_plot=i if config.save_scatterplot else None
            )
            stats_donut.append(normalise(stat_donut.item(), a_donut))
            stats_dent.append(normalise(stat_dent.item(), a_dent))

            update_experiments_log(
                i=i, stats_donut=stats_donut, stats_dent=stats_dent, config=config
            )

    # Plot results
    # 1. Set a “pretty” style
    sns.set_theme(style="whitegrid", palette="pastel")

    # 2. Create a figure with two side-by-side axes
    fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)

    # 3. First subplot: stats_donut
    # NOTE(review): the plotted values are area-normalised violation
    # fractions, while the axis label says "satisfied" — confirm the wording.
    sns.lineplot(data=stats_donut, ax=axes[0])
    axes[0].set_title("Donut Statistics", fontsize=14, fontweight="bold")
    axes[0].set_xlabel("Index")
    axes[0].set_ylabel("% Constraint satisfied")
    axes[0].grid(True, linestyle="--", alpha=0.7)

    # 4. Second subplot: stats_dent
    sns.lineplot(data=stats_dent, ax=axes[1])
    axes[1].set_title("Dent Statistics", fontsize=14, fontweight="bold")
    axes[1].set_xlabel("Index")
    axes[1].set_ylabel("% Constraint satisfied")
    axes[1].grid(True, linestyle="--", alpha=0.7)

    # 5. Add an overall title and tighten layout
    fig.suptitle("Comparison of Donut vs. Dent Trends", fontsize=16, fontweight="bold")
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    save_folder = (
        Path("results")
        / f"mlp{'_b' if config.brownian_motion else ''}_l{config.num_hidden_layers}_h{config.hidden_size}"
    )
    save_folder.mkdir(parents=True, exist_ok=True)
    plt.savefig(os.path.join(save_folder, "constraint_stats.png"))
    plt.close()
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse the command line into a Config. The name `config` is bound at
    # module scope here, and evaluate_constraints reads it as a global rather
    # than receiving it as an argument.
    config: Config = tyro.cli(Config)
    main(config)
|
|
|