Browse Source

feat: brownian motion

master
CALVO GONZALEZ Ramon 9 months ago
parent
commit
b4652896c0
  1. 50
      log.typ
  2. 294
      train.py

50
log.typ

@ -0,0 +1,50 @@
#import "@preview/typslides:1.2.5": *
#show: typslides.with(
ratio: "16-9",
theme: "bluey",
)
#front-slide(
title: "Diffusion and joint laws",
subtitle: ["And how time might help learn them"],
authors: ("Ramon"),
info: [#link("https://unige.ch")],
)
// #table-of-contents()
//#title-slide[
//Section 1: Introduction
//]
//#slide[
// == Overview
// - Background and motivation
//- Research question
// - Objectives
//]
#slide(title: "Data generation process")[
- Visualization of $p_"init"$. We want the diffusion to learn the 2 constraints (annulus and wedge).
#image("./results/p_data.png")
]
#slide(title: "Norm. flow after 500.000 steps")[
- Samples drawn from the Flow after training for 500.000 steps
#image("./results/mlp_l4_h64/scatter_500000.png")
]
#slide(title: "Brownian motion p_init")[
- Generation process: sample from $z_0 ~ p_"init"$, do 1 Brownian step and $z_1 ~ z_0 + cal(N)(0, sigma^2)$
#image("./results/brownian.png")
]
#slide(title: "Brownian motion training")[
- The new Flow Matching loss becomes:
$
cal(L)_"CFM" = EE_(t ~ cal(U)(0, 1), z_0 ~ p_"data", z_1 ~ z_0 + cal(N)(0, sigma^2), x ~ p_t (dot, z_1))[ ||u_t^theta (x | z_0) - u_t^"target" (x | z_1) ||^2 ]
$
]

294
train.py

@ -10,6 +10,7 @@ import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
from filelock import FileLock, Timeout from filelock import FileLock, Timeout
import numpy as np
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import flax.nnx as nnx import flax.nnx as nnx
@ -18,13 +19,19 @@ import optax
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.1" os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.1"
@dataclass @dataclass(frozen=True)
class Config: class Config:
"""Flow/DDM training of a simple distribution.""" """Flow/DDM training of a simple distribution."""
space_dimensions: int = 2 space_dimensions: int = 2
"""The dimensionality of the distribution's space.""" """The dimensionality of the distribution's space."""
brownian_motion: bool = False
"""Enables brownian motion."""
brownian_sigma: float = 0.1
"""The size of the brownian motion."""
num_hidden_layers: int = 4 num_hidden_layers: int = 4
"""Number of hidden layers in the MLP.""" """Number of hidden layers in the MLP."""
@ -111,6 +118,59 @@ def check_dent_constraint(x: jax.Array, d: float) -> jax.Array:
return jnp.logical_not(jnp.logical_and(theta < d / 2, theta > -d / 2)) return jnp.logical_not(jnp.logical_and(theta < d / 2, theta > -d / 2))
def brownian_motion(key: jax.random.PRNGKey, x0: jax.Array, sigma: float) -> jax.Array:
    """Take a single Brownian step from ``x0``: ``x0 + sigma * N(0, I)``.

    Args:
        key: PRNG key used to draw the Gaussian increment.
        x0: Starting points; the increment has the same shape.
        sigma: Standard deviation of the step.

    Returns:
        ``x0`` displaced by one Gaussian increment of scale ``sigma``.
    """
    step = jax.random.normal(key=key, shape=x0.shape) * sigma
    return x0 + step
# static_argnames (not static_argnums) because callers pass num_samples by
# keyword (see the `partial(...)` sampler in main); num_check_pts must also be
# static, since jnp.linspace needs a concrete point count at trace time.
@partial(jax.jit, static_argnames=("num_samples", "num_check_pts"))
def brownian_motion_step(
    key: jax.random.PRNGKey,
    num_samples: int,
    r1: float,
    r2: float,
    d: float = 0.0,
    sigma: float = 0.1,
    num_check_pts: int = 8,
) -> tuple[jax.Array, jax.Array]:
    """Sample pairs ``(x0, x1)`` where ``x1`` is one constrained Brownian step.

    ``x0`` is drawn from p_data (the dented annulus). Each step
    ``x1 = x0 + N(0, sigma^2)`` is rejection-sampled per particle until both
    the endpoint and ``num_check_pts`` equally spaced interior points of the
    segment ``[x0, x1]`` satisfy the annulus and wedge constraints, so the
    step never crosses outside the support.

    Args:
        key: PRNG key, consumed for the initial samples and all proposals.
        num_samples: Number of pairs to draw (static under jit).
        r1: Inner radius of the annulus.
        r2: Outer radius of the annulus.
        d: Angular width of the wedge removed from the annulus.
        sigma: Standard deviation of the Brownian step.
        num_check_pts: Interior points checked per segment (static under jit).

    Returns:
        Tuple ``(x0, x1)`` of arrays with shape ``(num_samples, 2)``.
    """
    key, key_init = jax.random.split(key)
    x0 = sample_p_data(key_init, num_samples, r1, r2, d)

    def points_satisfy(x: jax.Array) -> jax.Array:
        # A point is valid iff it lies inside the annulus AND outside the wedge.
        return jnp.logical_and(
            check_donut_constraint(x, r1, r2), check_dent_constraint(x, d)
        )

    # Interior segment parameters in (0, 1); endpoints are checked separately.
    ts = jnp.linspace(0.0, 1.0, num_check_pts + 2)[1:-1]

    def cond_fn(state):
        # Keep looping while any particle still lacks an accepted step.
        _, _, cs = state
        return jnp.any(~cs)

    def body_fn(state):
        key, x, cs = state
        key, _key = jax.random.split(key)
        # Propose a fresh step from the ORIGINAL x0 for every particle;
        # already-accepted particles are masked out below.
        x_prop = brownian_motion(_key, x0, sigma)
        ok_end = points_satisfy(x_prop)
        diff = x_prop - x
        # (num_samples, num_check_pts, 2): interior points of each segment.
        pts = x[:, None, :] + diff[:, None, :] * ts[None, :, None]
        ok_interior = jnp.all(
            points_satisfy(pts.reshape(-1, pts.shape[-1])).reshape(num_samples, -1),
            axis=1,
        )
        satisfied = ok_end & ok_interior
        # Only update particles that were not yet accepted and now pass.
        x = jnp.where(jnp.logical_and(~cs, satisfied)[:, None], x_prop, x)
        cs = jnp.logical_or(cs, satisfied)
        return key, x, cs

    init_cs = jnp.zeros((num_samples,), dtype=bool)
    init_state = (key, x0, init_cs)
    _, x1, _ = jax.lax.while_loop(cond_fn, body_fn, init_state)
    return x0, x1
def array_bool_statistics(x: jax.Array) -> float: def array_bool_statistics(x: jax.Array) -> float:
"""Computes the % of True in a bool array.""" """Computes the % of True in a bool array."""
assert x.dtype == jnp.bool_ assert x.dtype == jnp.bool_
@ -119,7 +179,13 @@ def array_bool_statistics(x: jax.Array) -> float:
def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int] = 0): def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int] = 0):
p_data_key, p_t_key = jax.random.split(key) key, p_data_key, p_t_key = jax.random.split(key, 3)
if config.brownian_motion:
key, key_b = jax.random.split(key)
z0 = sample_p_data(
key, num_samples=32_768, r1=config.r1, r2=config.r2, d=config.d
)
samples = ode_trajectory( samples = ode_trajectory(
p_t_key, p_t_key,
@ -127,6 +193,7 @@ def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int
num_samples=32_768, num_samples=32_768,
sample_steps=config.sample_steps, sample_steps=config.sample_steps,
space_dimensions=config.space_dimensions, space_dimensions=config.space_dimensions,
cond=z0 if config.brownian_motion else None,
) )
stats_donut = array_bool_statistics( stats_donut = array_bool_statistics(
@ -140,7 +207,8 @@ def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int
plt.figure() plt.figure()
sns.scatterplot(x=samples[:, 0], y=samples[:, 1], size=0.1) sns.scatterplot(x=samples[:, 0], y=samples[:, 1], size=0.1)
save_folder = ( save_folder = (
Path("results") / f"mlp_l{config.num_hidden_layers}_h{config.hidden_size}" Path("results")
/ f"mlp{'_b' if config.brownian_motion else ''}_l{config.num_hidden_layers}_h{config.hidden_size}"
) )
save_folder.mkdir(parents=True, exist_ok=True) save_folder.mkdir(parents=True, exist_ok=True)
plt.savefig( plt.savefig(
@ -238,7 +306,7 @@ beta_grad = jax.jit(jax.vmap(beta_scalar_grad))
def ode_step(model: MLP, x_t: jax.Array, t: jax.Array, h: float) -> jax.Array: def ode_step(model: MLP, x_t: jax.Array, t: jax.Array, h: float) -> jax.Array:
return x_t + h * model(x_t, t) return x_t[..., :2] + h * model(x_t, t)
@partial(nnx.jit, static_argnums=(2, 3, 4)) @partial(nnx.jit, static_argnums=(2, 3, 4))
@ -248,6 +316,7 @@ def ode_trajectory(
num_samples: int, num_samples: int,
sample_steps: int, sample_steps: int,
space_dimensions: int, space_dimensions: int,
cond: Optional[jax.Array] = None,
) -> jax.Array: ) -> jax.Array:
t = jnp.zeros((num_samples,)) t = jnp.zeros((num_samples,))
h = 1.0 / sample_steps h = 1.0 / sample_steps
@ -255,15 +324,11 @@ def ode_trajectory(
def body(i, state): def body(i, state):
t, x = state t, x = state
x = ode_step(model, x, t, h) x_cond = jnp.concat((x, cond), axis=-1) if cond is not None else x
x = ode_step(model, x_cond, t, h)
return (t + h, x) return (t + h, x)
_, x = jax.lax.fori_loop(0, config.sample_steps, body, (t, x)) _, x = jax.lax.fori_loop(0, config.sample_steps, body, (t, x))
# for i in range(sample_steps):
# x = ode_step(model, x, t, h)
# t = t + h
return x return x
@ -278,48 +343,16 @@ def sde_trajectory(model: MLP) -> jax.Array:
# --- Training ---------------------------------------- # --- Training ----------------------------------------
def main(config: Config): @nnx.jit
rngs = nnx.Rngs(config.seed) def train_step(model: MLP, optim: nnx.Optimizer, z: jax.Array, key: jax.random.PRNGKey):
key_e, key_t = jax.random.split(key, 2)
if config.show_p_data:
points = sample_p_data(rngs.params(), 32_768, config.r1, config.r2, config.d)
plt.figure()
sns.scatterplot(x=points[:, 0], y=points[:, 1], size=0.1)
if not os.path.exists("./results"):
os.mkdir("results")
plt.savefig("results/p_data.png", dpi=300, bbox_inches="tight")
exit()
model = MLP(
in_features=config.space_dimensions,
out_features=config.space_dimensions,
num_hidden_layers=config.num_hidden_layers,
hidden_size=config.hidden_size,
use_bias=config.mlp_bias,
fourier_features=config.fourier_dim,
fourier_max_period=config.fourier_max_period,
rngs=rngs,
)
optim = nnx.Optimizer(
model,
tx=optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=3e-4)),
)
@nnx.jit
def train_step(
model: MLP, optim: nnx.Optimizer, z: jax.Array, key: jax.random.PRNGKey
):
key_e, key_t = jax.random.split(key)
eps = jax.random.normal(key=key_e, shape=z.shape) eps = jax.random.normal(key=key_e, shape=z.shape)
t = jax.random.uniform(key=key_t, shape=[z.shape[0]]) t = jax.random.uniform(key=key_t, shape=[z.shape[0]])
x = alpha(t)[:, None] * z + beta(t)[:, None] * eps x = alpha(t)[:, None] * z + beta(t)[:, None] * eps
def loss_fn(model, z, t, eps): def loss_fn(model, z, t, eps):
loss = jnp.sum( loss = jnp.sum(
( (model(x, t) - (alpha_grad(t)[:, None] * z + beta_grad(t)[:, None] * eps))
model(x, t)
- (alpha_grad(t)[:, None] * z + beta_grad(t)[:, None] * eps)
)
** 2, ** 2,
axis=-1, axis=-1,
) )
@ -331,33 +364,63 @@ def main(config: Config):
optim.update(grads) optim.update(grads)
return loss return loss
cached_train_step = nnx.cached_partial(train_step, model, optim)
a_donut = (jnp.pi - 0.5 * config.d) * (config.r2**2 - config.r1**2) @nnx.jit
a_dent = 0.5 * config.d * (config.r2**2 - config.r1**2) def train_step_brownian(
stats_donut = [] model: MLP,
stats_dent = [] optim: nnx.Optimizer,
for i in tqdm(range(config.num_steps)): z: (jax.Array, jax.Array),
z = sample_p_data( key: jax.random.PRNGKey,
rngs.params(), ):
num_samples=config.batch_size, z0, z1 = z
r1=config.r1, assert z0.shape == z1.shape
r2=config.r2,
d=config.d, key_e, key_t = jax.random.split(key)
eps = jax.random.normal(key=key_e, shape=z1.shape)
t = jax.random.uniform(key=key_t, shape=[z1.shape[0]])
x = alpha(t)[:, None] * z1 + beta(t)[:, None] * eps
def loss_fn(model):
x_brownian = jnp.concat((x, z0), axis=-1)
loss = jnp.sum(
(
model(x_brownian, t)
- (alpha_grad(t)[:, None] * beta_grad(t)[:, None] * eps)
)
** 2,
axis=-1,
) )
return jnp.mean(loss)
_ = cached_train_step(z, rngs.params()) value_grad_fn = nnx.value_and_grad(loss_fn)
loss, grads = value_grad_fn(model)
if ( optim.update(grads)
config.evaluate_constraints_every != 0 return loss
and i % config.evaluate_constraints_every == 0
):
stat_donut, stat_dent = evaluate_constraints( def setup_model_and_optimizer(config: Config, rngs: nnx.Rngs) -> (MLP, nnx.Optimizer):
rngs.params(), model, save_plot=i if config.save_scatterplot else None model = MLP(
in_features=2 * config.space_dimensions
if config.brownian_motion
else config.space_dimensions,
out_features=config.space_dimensions,
num_hidden_layers=config.num_hidden_layers,
hidden_size=config.hidden_size,
use_bias=config.mlp_bias,
fourier_features=config.fourier_dim,
fourier_max_period=config.fourier_max_period,
rngs=rngs,
) )
stats_donut.append(stat_donut.item() / a_donut) optim = nnx.Optimizer(
stats_dent.append(stat_dent.item() / a_dent) model,
tx=optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=3e-4)),
)
return model, optim
def update_experiments_log(i, stats_donut, stats_dent, config: Config) -> None:
try: try:
with FileLock(os.path.join("results", "experiments.lock")): with FileLock(os.path.join("results", "experiments.lock")):
file = os.path.join("results", "experiments.csv") file = os.path.join("results", "experiments.csv")
@ -377,6 +440,7 @@ def main(config: Config):
"p_donut", "p_donut",
"p_dent", "p_dent",
"step", "step",
"brownian",
] ]
) )
@ -393,6 +457,7 @@ def main(config: Config):
"p_donut": [stats_donut[-1]], "p_donut": [stats_donut[-1]],
"p_dent": [stats_dent[-1]], "p_dent": [stats_dent[-1]],
"step": [i], "step": [i],
"brownian": [config.brownian_motion],
} }
) )
@ -401,6 +466,96 @@ def main(config: Config):
except Timeout: except Timeout:
print("Timeout!!!") print("Timeout!!!")
def plot_p_data(key: jax.random.PRNGKey, config: Config):
    """Render a diagnostic plot of the data-generation process into ./results.

    Without Brownian motion: a scatter of ``p_data`` samples is saved to
    ``results/p_data.png``. With Brownian motion: segments from each ``x0`` to
    its constrained Brownian endpoint ``x1`` are saved to
    ``results/brownian.png``.

    Args:
        key: PRNG key used for sampling.
        config: Run configuration (radii, wedge width, batch size, sigma).
    """
    # Both branches write into ./results, so ensure it exists up front
    # (previously only the non-Brownian branch created it, and the Brownian
    # branch crashed on a fresh checkout).
    os.makedirs("results", exist_ok=True)
    if not config.brownian_motion:
        points = sample_p_data(key, config.batch_size, config.r1, config.r2, config.d)
        plt.figure()
        sns.scatterplot(x=points[:, 0], y=points[:, 1], size=0.1)
        plt.savefig("results/p_data.png", dpi=300, bbox_inches="tight")
    else:
        x0, x1 = brownian_motion_step(
            key,
            config.batch_size,
            config.r1,
            config.r2,
            config.d,
            config.brownian_sigma,
        )
        # Matplotlib works on host arrays; pull the device arrays over once.
        x0_np = np.array(x0)
        x1_np = np.array(x1)
        plt.figure(figsize=(8, 8))
        # Draw a thin segment from each start point to its Brownian endpoint.
        for start, end in zip(x0_np, x1_np):
            plt.plot([start[0], end[0]], [start[1], end[1]], linewidth=0.5, alpha=0.5)
        # Scatter the start points ...
        plt.scatter(x0_np[:, 0], x0_np[:, 1], s=10, label="x₀")
        # ... and the end points.
        plt.scatter(x1_np[:, 0], x1_np[:, 1], s=10, label="x₁")
        plt.legend(loc="upper right")
        plt.axis("equal")
        plt.title("Brownian step in the dented annulus")
        plt.xlabel("x")
        plt.ylabel("y")
        plt.tight_layout()
        plt.savefig("results/brownian.png")
def main(config: Config):
rngs = nnx.Rngs(config.seed)
if config.show_p_data:
plot_p_data(rngs.params(), config)
exit()
model, optim = setup_model_and_optimizer(config, rngs)
step_fn = train_step_brownian if config.brownian_motion else train_step
step_fn = nnx.cached_partial(step_fn, model, optim)
sampler = partial(
(
partial(brownian_motion_step, sigma=config.brownian_sigma)
if config.brownian_motion
else sample_p_data
),
num_samples=config.batch_size,
r1=config.r1,
r2=config.r2,
d=config.d,
)
a_donut = (jnp.pi - 0.5 * config.d) * (config.r2**2 - config.r1**2)
a_dent = 0.5 * config.d * (config.r2**2 - config.r1**2)
stats_donut = []
stats_dent = []
for i in tqdm(range(config.num_steps)):
z = sampler(rngs.params())
_ = step_fn(z, rngs.params())
if (
config.evaluate_constraints_every != 0
and i > 0
and i % config.evaluate_constraints_every == 0
):
stat_donut, stat_dent = evaluate_constraints(
rngs.params(), model, save_plot=i if config.save_scatterplot else None
)
stats_donut.append(stat_donut.item() / a_donut)
stats_dent.append(stat_dent.item() / a_dent)
update_experiments_log(
i=i, stats_donut=stats_donut, stats_dent=stats_dent, config=config
)
# Plot results # Plot results
# 1. Set a “pretty” style # 1. Set a “pretty” style
sns.set_theme(style="whitegrid", palette="pastel") sns.set_theme(style="whitegrid", palette="pastel")
@ -427,7 +582,8 @@ def main(config: Config):
fig.tight_layout(rect=[0, 0.03, 1, 0.95]) fig.tight_layout(rect=[0, 0.03, 1, 0.95])
save_folder = ( save_folder = (
Path("results") / f"mlp_l{config.num_hidden_layers}_h{config.hidden_size}" Path("results")
/ f"mlp{'_b' if config.brownian_motion else ''}_l{config.num_hidden_layers}_h{config.hidden_size}"
) )
save_folder.mkdir(parents=True, exist_ok=True) save_folder.mkdir(parents=True, exist_ok=True)
plt.savefig(os.path.join(save_folder, "constraint_stats.png")) plt.savefig(os.path.join(save_folder, "constraint_stats.png"))
@ -435,5 +591,5 @@ def main(config: Config):
if __name__ == "__main__": if __name__ == "__main__":
config = tyro.cli(Config) config: Config = tyro.cli(Config)
main(config) main(config)

Loading…
Cancel
Save