|
|
|
@ -10,6 +10,7 @@ import matplotlib.pyplot as plt |
|
|
|
import pandas as pd |
|
|
|
from filelock import FileLock, Timeout |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
import jax |
|
|
|
import jax.numpy as jnp |
|
|
|
import flax.nnx as nnx |
|
|
|
@ -18,13 +19,19 @@ import optax |
|
|
|
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.1" |
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
|
|
@dataclass(frozen=True) |
|
|
|
class Config: |
|
|
|
"""Flow/DDM training of a simple distribution.""" |
|
|
|
|
|
|
|
space_dimensions: int = 2 |
|
|
|
"""The dimensionality of the distribution's space.""" |
|
|
|
|
|
|
|
brownian_motion: bool = False |
|
|
|
"""Enables brownian motion.""" |
|
|
|
|
|
|
|
brownian_sigma: float = 0.1 |
|
|
|
"""The size of the brownian motion.""" |
|
|
|
|
|
|
|
num_hidden_layers: int = 4 |
|
|
|
"""Number of hidden layers in the MLP.""" |
|
|
|
|
|
|
|
@ -111,6 +118,59 @@ def check_dent_constraint(x: jax.Array, d: float) -> jax.Array: |
|
|
|
return jnp.logical_not(jnp.logical_and(theta < d / 2, theta > -d / 2)) |
|
|
|
|
|
|
|
|
|
|
|
def brownian_motion(key: jax.random.PRNGKey, x0: jax.Array, sigma: float) -> jax.Array: |
|
|
|
noise = sigma * jax.random.normal(key=key, shape=x0.shape) |
|
|
|
return x0 + noise |
|
|
|
|
|
|
|
|
|
|
|
@partial(jax.jit, static_argnums=(1,)) |
|
|
|
def brownian_motion_step( |
|
|
|
key: jax.random.PRNGKey, |
|
|
|
num_samples: int, |
|
|
|
r1: float, |
|
|
|
r2: float, |
|
|
|
d: float = 0.0, |
|
|
|
sigma: float = 0.1, |
|
|
|
num_check_pts: int = 8, |
|
|
|
) -> (jax.Array, jax.Array): |
|
|
|
key, key_init = jax.random.split(key) |
|
|
|
x0 = sample_p_data(key_init, num_samples, r1, r2, d) |
|
|
|
|
|
|
|
def points_satisfy(x: jax.Array) -> jax.Array: |
|
|
|
return jnp.logical_and( |
|
|
|
check_donut_constraint(x, r1, r2), check_dent_constraint(x, d) |
|
|
|
) |
|
|
|
|
|
|
|
ts = jnp.linspace(0.0, 1.0, num_check_pts + 2)[1:-1] |
|
|
|
|
|
|
|
def cond_fn(state): |
|
|
|
_, _, cs = state |
|
|
|
return jnp.any(~cs) |
|
|
|
|
|
|
|
def body_fn(state): |
|
|
|
key, x, cs = state |
|
|
|
key, _key = jax.random.split(key) |
|
|
|
x_prop = brownian_motion(_key, x0, sigma) |
|
|
|
ok_end = points_satisfy(x_prop) |
|
|
|
|
|
|
|
diff = x_prop - x |
|
|
|
pts = x[:, None, :] + diff[:, None, :] * ts[None, :, None] |
|
|
|
ok_interior = jnp.all( |
|
|
|
points_satisfy(pts.reshape(-1, pts.shape[-1])).reshape(num_samples, -1), |
|
|
|
axis=1, |
|
|
|
) |
|
|
|
satisfied = ok_end & ok_interior |
|
|
|
x = jnp.where(jnp.logical_and(~cs, satisfied)[:, None], x_prop, x) |
|
|
|
cs = jnp.logical_or(cs, satisfied) |
|
|
|
return key, x, cs |
|
|
|
|
|
|
|
init_cs = jnp.zeros((num_samples,), dtype=bool) |
|
|
|
init_state = (key, x0, init_cs) |
|
|
|
_, x1, _ = jax.lax.while_loop(cond_fn, body_fn, init_state) |
|
|
|
|
|
|
|
return x0, x1 |
|
|
|
|
|
|
|
|
|
|
|
def array_bool_statistics(x: jax.Array) -> float: |
|
|
|
"""Computes the % of True in a bool array.""" |
|
|
|
assert x.dtype == jnp.bool_ |
|
|
|
@ -119,7 +179,13 @@ def array_bool_statistics(x: jax.Array) -> float: |
|
|
|
|
|
|
|
|
|
|
|
def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int] = 0): |
|
|
|
p_data_key, p_t_key = jax.random.split(key) |
|
|
|
key, p_data_key, p_t_key = jax.random.split(key, 3) |
|
|
|
|
|
|
|
if config.brownian_motion: |
|
|
|
key, key_b = jax.random.split(key) |
|
|
|
z0 = sample_p_data( |
|
|
|
key, num_samples=32_768, r1=config.r1, r2=config.r2, d=config.d |
|
|
|
) |
|
|
|
|
|
|
|
samples = ode_trajectory( |
|
|
|
p_t_key, |
|
|
|
@ -127,6 +193,7 @@ def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int |
|
|
|
num_samples=32_768, |
|
|
|
sample_steps=config.sample_steps, |
|
|
|
space_dimensions=config.space_dimensions, |
|
|
|
cond=z0 if config.brownian_motion else None, |
|
|
|
) |
|
|
|
|
|
|
|
stats_donut = array_bool_statistics( |
|
|
|
@ -140,7 +207,8 @@ def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int |
|
|
|
plt.figure() |
|
|
|
sns.scatterplot(x=samples[:, 0], y=samples[:, 1], size=0.1) |
|
|
|
save_folder = ( |
|
|
|
Path("results") / f"mlp_l{config.num_hidden_layers}_h{config.hidden_size}" |
|
|
|
Path("results") |
|
|
|
/ f"mlp{'_b' if config.brownian_motion else ''}_l{config.num_hidden_layers}_h{config.hidden_size}" |
|
|
|
) |
|
|
|
save_folder.mkdir(parents=True, exist_ok=True) |
|
|
|
plt.savefig( |
|
|
|
@ -238,7 +306,7 @@ beta_grad = jax.jit(jax.vmap(beta_scalar_grad)) |
|
|
|
|
|
|
|
|
|
|
|
def ode_step(model: MLP, x_t: jax.Array, t: jax.Array, h: float) -> jax.Array: |
|
|
|
return x_t + h * model(x_t, t) |
|
|
|
return x_t[..., :2] + h * model(x_t, t) |
|
|
|
|
|
|
|
|
|
|
|
@partial(nnx.jit, static_argnums=(2, 3, 4)) |
|
|
|
@ -248,6 +316,7 @@ def ode_trajectory( |
|
|
|
num_samples: int, |
|
|
|
sample_steps: int, |
|
|
|
space_dimensions: int, |
|
|
|
cond: Optional[jax.Array] = None, |
|
|
|
) -> jax.Array: |
|
|
|
t = jnp.zeros((num_samples,)) |
|
|
|
h = 1.0 / sample_steps |
|
|
|
@ -255,15 +324,11 @@ def ode_trajectory( |
|
|
|
|
|
|
|
def body(i, state): |
|
|
|
t, x = state |
|
|
|
x = ode_step(model, x, t, h) |
|
|
|
x_cond = jnp.concat((x, cond), axis=-1) if cond is not None else x |
|
|
|
x = ode_step(model, x_cond, t, h) |
|
|
|
return (t + h, x) |
|
|
|
|
|
|
|
_, x = jax.lax.fori_loop(0, config.sample_steps, body, (t, x)) |
|
|
|
|
|
|
|
# for i in range(sample_steps): |
|
|
|
# x = ode_step(model, x, t, h) |
|
|
|
# t = t + h |
|
|
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
@ -278,20 +343,67 @@ def sde_trajectory(model: MLP) -> jax.Array: |
|
|
|
# --- Training ---------------------------------------- |
|
|
|
|
|
|
|
|
|
|
|
def main(config: Config): |
|
|
|
rngs = nnx.Rngs(config.seed) |
|
|
|
@nnx.jit |
|
|
|
def train_step(model: MLP, optim: nnx.Optimizer, z: jax.Array, key: jax.random.PRNGKey): |
|
|
|
key_e, key_t = jax.random.split(key, 2) |
|
|
|
eps = jax.random.normal(key=key_e, shape=z.shape) |
|
|
|
t = jax.random.uniform(key=key_t, shape=[z.shape[0]]) |
|
|
|
x = alpha(t)[:, None] * z + beta(t)[:, None] * eps |
|
|
|
|
|
|
|
def loss_fn(model, z, t, eps): |
|
|
|
loss = jnp.sum( |
|
|
|
(model(x, t) - (alpha_grad(t)[:, None] * z + beta_grad(t)[:, None] * eps)) |
|
|
|
** 2, |
|
|
|
axis=-1, |
|
|
|
) |
|
|
|
return jnp.mean(loss) |
|
|
|
|
|
|
|
value_grad_fn = nnx.value_and_grad(loss_fn) |
|
|
|
loss, grads = value_grad_fn(model, z, t, eps) |
|
|
|
|
|
|
|
optim.update(grads) |
|
|
|
return loss |
|
|
|
|
|
|
|
if config.show_p_data: |
|
|
|
points = sample_p_data(rngs.params(), 32_768, config.r1, config.r2, config.d) |
|
|
|
plt.figure() |
|
|
|
sns.scatterplot(x=points[:, 0], y=points[:, 1], size=0.1) |
|
|
|
if not os.path.exists("./results"): |
|
|
|
os.mkdir("results") |
|
|
|
plt.savefig("results/p_data.png", dpi=300, bbox_inches="tight") |
|
|
|
exit() |
|
|
|
|
|
|
|
@nnx.jit |
|
|
|
def train_step_brownian( |
|
|
|
model: MLP, |
|
|
|
optim: nnx.Optimizer, |
|
|
|
z: (jax.Array, jax.Array), |
|
|
|
key: jax.random.PRNGKey, |
|
|
|
): |
|
|
|
z0, z1 = z |
|
|
|
assert z0.shape == z1.shape |
|
|
|
|
|
|
|
key_e, key_t = jax.random.split(key) |
|
|
|
eps = jax.random.normal(key=key_e, shape=z1.shape) |
|
|
|
t = jax.random.uniform(key=key_t, shape=[z1.shape[0]]) |
|
|
|
x = alpha(t)[:, None] * z1 + beta(t)[:, None] * eps |
|
|
|
|
|
|
|
def loss_fn(model): |
|
|
|
x_brownian = jnp.concat((x, z0), axis=-1) |
|
|
|
loss = jnp.sum( |
|
|
|
( |
|
|
|
model(x_brownian, t) |
|
|
|
- (alpha_grad(t)[:, None] * beta_grad(t)[:, None] * eps) |
|
|
|
) |
|
|
|
** 2, |
|
|
|
axis=-1, |
|
|
|
) |
|
|
|
return jnp.mean(loss) |
|
|
|
|
|
|
|
value_grad_fn = nnx.value_and_grad(loss_fn) |
|
|
|
loss, grads = value_grad_fn(model) |
|
|
|
|
|
|
|
optim.update(grads) |
|
|
|
return loss |
|
|
|
|
|
|
|
|
|
|
|
def setup_model_and_optimizer(config: Config, rngs: nnx.Rngs) -> (MLP, nnx.Optimizer): |
|
|
|
model = MLP( |
|
|
|
in_features=config.space_dimensions, |
|
|
|
in_features=2 * config.space_dimensions |
|
|
|
if config.brownian_motion |
|
|
|
else config.space_dimensions, |
|
|
|
out_features=config.space_dimensions, |
|
|
|
num_hidden_layers=config.num_hidden_layers, |
|
|
|
hidden_size=config.hidden_size, |
|
|
|
@ -305,51 +417,133 @@ def main(config: Config): |
|
|
|
tx=optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=3e-4)), |
|
|
|
) |
|
|
|
|
|
|
|
@nnx.jit |
|
|
|
def train_step( |
|
|
|
model: MLP, optim: nnx.Optimizer, z: jax.Array, key: jax.random.PRNGKey |
|
|
|
): |
|
|
|
key_e, key_t = jax.random.split(key) |
|
|
|
eps = jax.random.normal(key=key_e, shape=z.shape) |
|
|
|
t = jax.random.uniform(key=key_t, shape=[z.shape[0]]) |
|
|
|
x = alpha(t)[:, None] * z + beta(t)[:, None] * eps |
|
|
|
|
|
|
|
def loss_fn(model, z, t, eps): |
|
|
|
loss = jnp.sum( |
|
|
|
( |
|
|
|
model(x, t) |
|
|
|
- (alpha_grad(t)[:, None] * z + beta_grad(t)[:, None] * eps) |
|
|
|
return model, optim |
|
|
|
|
|
|
|
|
|
|
|
def update_experiments_log(i, stats_donut, stats_dent, config: Config) -> None: |
|
|
|
try: |
|
|
|
with FileLock(os.path.join("results", "experiments.lock")): |
|
|
|
file = os.path.join("results", "experiments.csv") |
|
|
|
if os.path.exists(file): |
|
|
|
df = pd.read_csv(file) |
|
|
|
else: |
|
|
|
df = pd.DataFrame( |
|
|
|
columns=[ |
|
|
|
"num_hidden_layers", |
|
|
|
"hidden_size", |
|
|
|
"mlp_bias", |
|
|
|
"fourier_dim", |
|
|
|
"r1", |
|
|
|
"r2", |
|
|
|
"d", |
|
|
|
"sample_steps", |
|
|
|
"p_donut", |
|
|
|
"p_dent", |
|
|
|
"step", |
|
|
|
"brownian", |
|
|
|
] |
|
|
|
) |
|
|
|
** 2, |
|
|
|
axis=-1, |
|
|
|
|
|
|
|
new_row = pd.DataFrame( |
|
|
|
{ |
|
|
|
"num_hidden_layers": [config.num_hidden_layers], |
|
|
|
"hidden_size": [config.hidden_size], |
|
|
|
"mlp_bias": [config.mlp_bias], |
|
|
|
"fourier_dim": [config.fourier_dim], |
|
|
|
"r1": [config.r1], |
|
|
|
"r2": [config.r2], |
|
|
|
"d": [config.d], |
|
|
|
"sample_steps": [config.sample_steps], |
|
|
|
"p_donut": [stats_donut[-1]], |
|
|
|
"p_dent": [stats_dent[-1]], |
|
|
|
"step": [i], |
|
|
|
"brownian": [config.brownian_motion], |
|
|
|
} |
|
|
|
) |
|
|
|
return jnp.mean(loss) |
|
|
|
|
|
|
|
value_grad_fn = nnx.value_and_grad(loss_fn) |
|
|
|
loss, grads = value_grad_fn(model, z, t, eps) |
|
|
|
df = pd.concat([df, new_row], ignore_index=True) |
|
|
|
df.to_csv(file, index=False) |
|
|
|
except Timeout: |
|
|
|
print("Timeout!!!") |
|
|
|
|
|
|
|
|
|
|
|
def plot_p_data(key: jax.random.PRNGKey, config: Config): |
|
|
|
if not config.brownian_motion: |
|
|
|
points = sample_p_data(key, config.batch_size, config.r1, config.r2, config.d) |
|
|
|
plt.figure() |
|
|
|
sns.scatterplot(x=points[:, 0], y=points[:, 1], size=0.1) |
|
|
|
if not os.path.exists("./results"): |
|
|
|
os.mkdir("results") |
|
|
|
plt.savefig("results/p_data.png", dpi=300, bbox_inches="tight") |
|
|
|
else: |
|
|
|
x0, x1 = brownian_motion_step( |
|
|
|
key, |
|
|
|
config.batch_size, |
|
|
|
config.r1, |
|
|
|
config.r2, |
|
|
|
config.d, |
|
|
|
config.brownian_sigma, |
|
|
|
) |
|
|
|
|
|
|
|
optim.update(grads) |
|
|
|
return loss |
|
|
|
# If you have x0, x1 as JAX arrays, first convert: |
|
|
|
x0_np = np.array(x0) # shape (1024,2) |
|
|
|
x1_np = np.array(x1) # shape (1024,2) |
|
|
|
|
|
|
|
cached_train_step = nnx.cached_partial(train_step, model, optim) |
|
|
|
plt.figure(figsize=(8, 8)) |
|
|
|
|
|
|
|
# draw a tiny line for each sample |
|
|
|
for start, end in zip(x0_np, x1_np): |
|
|
|
plt.plot([start[0], end[0]], [start[1], end[1]], linewidth=0.5, alpha=0.5) |
|
|
|
|
|
|
|
# scatter the start points |
|
|
|
plt.scatter(x0_np[:, 0], x0_np[:, 1], s=10, label="x₀") |
|
|
|
|
|
|
|
# scatter the end points |
|
|
|
plt.scatter(x1_np[:, 0], x1_np[:, 1], s=10, label="x₁") |
|
|
|
|
|
|
|
plt.legend(loc="upper right") |
|
|
|
plt.axis("equal") |
|
|
|
plt.title("Brownian step in the dented annulus") |
|
|
|
plt.xlabel("x") |
|
|
|
plt.ylabel("y") |
|
|
|
plt.tight_layout() |
|
|
|
plt.savefig("results/brownian.png") |
|
|
|
|
|
|
|
|
|
|
|
def main(config: Config): |
|
|
|
rngs = nnx.Rngs(config.seed) |
|
|
|
|
|
|
|
if config.show_p_data: |
|
|
|
plot_p_data(rngs.params(), config) |
|
|
|
exit() |
|
|
|
|
|
|
|
model, optim = setup_model_and_optimizer(config, rngs) |
|
|
|
|
|
|
|
step_fn = train_step_brownian if config.brownian_motion else train_step |
|
|
|
step_fn = nnx.cached_partial(step_fn, model, optim) |
|
|
|
sampler = partial( |
|
|
|
( |
|
|
|
partial(brownian_motion_step, sigma=config.brownian_sigma) |
|
|
|
if config.brownian_motion |
|
|
|
else sample_p_data |
|
|
|
), |
|
|
|
num_samples=config.batch_size, |
|
|
|
r1=config.r1, |
|
|
|
r2=config.r2, |
|
|
|
d=config.d, |
|
|
|
) |
|
|
|
|
|
|
|
a_donut = (jnp.pi - 0.5 * config.d) * (config.r2**2 - config.r1**2) |
|
|
|
a_dent = 0.5 * config.d * (config.r2**2 - config.r1**2) |
|
|
|
stats_donut = [] |
|
|
|
stats_dent = [] |
|
|
|
for i in tqdm(range(config.num_steps)): |
|
|
|
z = sample_p_data( |
|
|
|
rngs.params(), |
|
|
|
num_samples=config.batch_size, |
|
|
|
r1=config.r1, |
|
|
|
r2=config.r2, |
|
|
|
d=config.d, |
|
|
|
) |
|
|
|
|
|
|
|
_ = cached_train_step(z, rngs.params()) |
|
|
|
z = sampler(rngs.params()) |
|
|
|
_ = step_fn(z, rngs.params()) |
|
|
|
|
|
|
|
if ( |
|
|
|
config.evaluate_constraints_every != 0 |
|
|
|
and i > 0 |
|
|
|
and i % config.evaluate_constraints_every == 0 |
|
|
|
): |
|
|
|
stat_donut, stat_dent = evaluate_constraints( |
|
|
|
@ -358,48 +552,9 @@ def main(config: Config): |
|
|
|
stats_donut.append(stat_donut.item() / a_donut) |
|
|
|
stats_dent.append(stat_dent.item() / a_dent) |
|
|
|
|
|
|
|
try: |
|
|
|
with FileLock(os.path.join("results", "experiments.lock")): |
|
|
|
file = os.path.join("results", "experiments.csv") |
|
|
|
if os.path.exists(file): |
|
|
|
df = pd.read_csv(file) |
|
|
|
else: |
|
|
|
df = pd.DataFrame( |
|
|
|
columns=[ |
|
|
|
"num_hidden_layers", |
|
|
|
"hidden_size", |
|
|
|
"mlp_bias", |
|
|
|
"fourier_dim", |
|
|
|
"r1", |
|
|
|
"r2", |
|
|
|
"d", |
|
|
|
"sample_steps", |
|
|
|
"p_donut", |
|
|
|
"p_dent", |
|
|
|
"step", |
|
|
|
] |
|
|
|
) |
|
|
|
|
|
|
|
new_row = pd.DataFrame( |
|
|
|
{ |
|
|
|
"num_hidden_layers": [config.num_hidden_layers], |
|
|
|
"hidden_size": [config.hidden_size], |
|
|
|
"mlp_bias": [config.mlp_bias], |
|
|
|
"fourier_dim": [config.fourier_dim], |
|
|
|
"r1": [config.r1], |
|
|
|
"r2": [config.r2], |
|
|
|
"d": [config.d], |
|
|
|
"sample_steps": [config.sample_steps], |
|
|
|
"p_donut": [stats_donut[-1]], |
|
|
|
"p_dent": [stats_dent[-1]], |
|
|
|
"step": [i], |
|
|
|
} |
|
|
|
) |
|
|
|
|
|
|
|
df = pd.concat([df, new_row], ignore_index=True) |
|
|
|
df.to_csv(file, index=False) |
|
|
|
except Timeout: |
|
|
|
print("Timeout!!!") |
|
|
|
update_experiments_log( |
|
|
|
i=i, stats_donut=stats_donut, stats_dent=stats_dent, config=config |
|
|
|
) |
|
|
|
|
|
|
|
# Plot results |
|
|
|
# 1. Set a “pretty” style |
|
|
|
@ -427,7 +582,8 @@ def main(config: Config): |
|
|
|
fig.tight_layout(rect=[0, 0.03, 1, 0.95]) |
|
|
|
|
|
|
|
save_folder = ( |
|
|
|
Path("results") / f"mlp_l{config.num_hidden_layers}_h{config.hidden_size}" |
|
|
|
Path("results") |
|
|
|
/ f"mlp{'_b' if config.brownian_motion else ''}_l{config.num_hidden_layers}_h{config.hidden_size}" |
|
|
|
) |
|
|
|
save_folder.mkdir(parents=True, exist_ok=True) |
|
|
|
plt.savefig(os.path.join(save_folder, "constraint_stats.png")) |
|
|
|
@ -435,5 +591,5 @@ def main(config: Config): |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
config = tyro.cli(Config) |
|
|
|
config: Config = tyro.cli(Config) |
|
|
|
main(config) |
|
|
|
|