
feat: brownian motion

Branch: master
Author: CALVO GONZALEZ Ramon, 7 months ago
Commit: b4652896c0
Changed files:
  1. log.typ (50 lines)
  2. train.py (348 lines)

log.typ (50 lines)

@@ -0,0 +1,50 @@
#import "@preview/typslides:1.2.5": *
#show: typslides.with(
  ratio: "16-9",
  theme: "bluey",
)
#front-slide(
  title: "Diffusion and joint laws",
  subtitle: [And how time might help learn them],
  authors: ("Ramon",),
  info: [#link("https://unige.ch")],
)
// #table-of-contents()
//#title-slide[
//Section 1: Introduction
//]
//#slide[
// == Overview
// - Background and motivation
//- Research question
// - Objectives
//]
#slide(title: "Data generation process")[
- Visualization of $p_"init"$. We want the diffusion te learn the 2 constraints (annulus and wedge).
#image("./results/p_data.png")
]
#slide(title: "Norm. flow after 500.000 steps")[
- Samples drawn from the Flow after training for 500.000 steps
#image("./results/mlp_l4_h64/scatter_500000.png")
]
#slide(title: "Brownian motion p_init")[
- Generation process: sample form $z_0 ~ p_"init"$, do 1 brownian step and $z_1 ~ z_0 + cal(N)(0, sigma^2)$
#image("./results/brownian.png")
]
#slide(title: "Brownian motion training")[
- The new Flow Matching loss becomes:
$
cal(L)_"CFM" = EE_(t ~ cal(U)(0, 1), z_0 ~ p_"data", z_1 ~ z_0 + cal(N)(0, sigma^2), x ~ p_t (dot, z_1))[ ||u_t^theta (x | z_0) - u_t^"target" (x | z_1) ||^2 ]
$
]
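In words: draw a data point z_0, take one Brownian step to get z_1, noise z_1 along the probability path, and regress the conditional velocity while the network is conditioned on the starting point z_0. A minimal JAX sketch of this objective, mirroring the `train_step_brownian` added below (`model` and the schedule functions `alpha`/`beta`/`alpha_grad`/`beta_grad` are assumed to behave like those in `train.py`; the model takes the concatenation of x and z_0):

```python
import jax
import jax.numpy as jnp

def cfm_loss(model, key, z0, sigma, alpha, beta, alpha_grad, beta_grad):
    """Sketch of the Brownian-conditioned CFM loss from the slide.

    Unlike `brownian_motion_step` in train.py, the Brownian step here is
    unconstrained (no rejection of steps leaving the annulus/wedge).
    """
    key_b, key_e, key_t = jax.random.split(key, 3)
    z1 = z0 + sigma * jax.random.normal(key_b, z0.shape)  # one Brownian step
    eps = jax.random.normal(key_e, z1.shape)
    t = jax.random.uniform(key_t, (z1.shape[0],))
    x = alpha(t)[:, None] * z1 + beta(t)[:, None] * eps   # x ~ p_t(. | z1)
    target = alpha_grad(t)[:, None] * z1 + beta_grad(t)[:, None] * eps
    pred = model(jnp.concatenate((x, z0), axis=-1), t)    # condition on z0
    return jnp.mean(jnp.sum((pred - target) ** 2, axis=-1))
```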

train.py (348 lines)

@@ -10,6 +10,7 @@ import matplotlib.pyplot as plt
import pandas as pd
from filelock import FileLock, Timeout
import numpy as np
import jax
import jax.numpy as jnp
import flax.nnx as nnx
@@ -18,13 +19,19 @@ import optax
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.1"
@dataclass
@dataclass(frozen=True)
class Config:
    """Flow/DDM training of a simple distribution."""

    space_dimensions: int = 2
    """The dimensionality of the distribution's space."""
    brownian_motion: bool = False
    """Enables the Brownian-motion data-generation process."""
    brownian_sigma: float = 0.1
    """Standard deviation of a single Brownian step."""
    num_hidden_layers: int = 4
    """Number of hidden layers in the MLP."""
@@ -111,6 +118,59 @@ def check_dent_constraint(x: jax.Array, d: float) -> jax.Array:
    return jnp.logical_not(jnp.logical_and(theta < d / 2, theta > -d / 2))

def brownian_motion(key: jax.random.PRNGKey, x0: jax.Array, sigma: float) -> jax.Array:
    noise = sigma * jax.random.normal(key=key, shape=x0.shape)
    return x0 + noise

# Rejection sampler: for each start point, propose Brownian endpoints until the
# endpoint and the interpolated interior points all satisfy both constraints.
@partial(jax.jit, static_argnums=(1,))
def brownian_motion_step(
    key: jax.random.PRNGKey,
    num_samples: int,
    r1: float,
    r2: float,
    d: float = 0.0,
    sigma: float = 0.1,
    num_check_pts: int = 8,
) -> tuple[jax.Array, jax.Array]:
    key, key_init = jax.random.split(key)
    x0 = sample_p_data(key_init, num_samples, r1, r2, d)

    def points_satisfy(x: jax.Array) -> jax.Array:
        return jnp.logical_and(
            check_donut_constraint(x, r1, r2), check_dent_constraint(x, d)
        )

    # Interior check points along each proposed segment (endpoints excluded),
    # used to reject steps that would cross the wedge or leave the annulus.
    ts = jnp.linspace(0.0, 1.0, num_check_pts + 2)[1:-1]

    def cond_fn(state):
        _, _, cs = state
        return jnp.any(~cs)

    def body_fn(state):
        key, x, cs = state
        key, _key = jax.random.split(key)
        x_prop = brownian_motion(_key, x0, sigma)
        ok_end = points_satisfy(x_prop)
        diff = x_prop - x
        pts = x[:, None, :] + diff[:, None, :] * ts[None, :, None]
        ok_interior = jnp.all(
            points_satisfy(pts.reshape(-1, pts.shape[-1])).reshape(num_samples, -1),
            axis=1,
        )
        satisfied = ok_end & ok_interior
        # Accept newly satisfied proposals; keep already-accepted samples fixed.
        x = jnp.where(jnp.logical_and(~cs, satisfied)[:, None], x_prop, x)
        cs = jnp.logical_or(cs, satisfied)
        return key, x, cs

    init_cs = jnp.zeros((num_samples,), dtype=bool)
    init_state = (key, x0, init_cs)
    _, x1, _ = jax.lax.while_loop(cond_fn, body_fn, init_state)
    return x0, x1
def array_bool_statistics(x: jax.Array) -> float:
    """Computes the % of True entries in a bool array."""
    assert x.dtype == jnp.bool_
@@ -119,7 +179,13 @@ def array_bool_statistics(x: jax.Array) -> float:
def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int] = 0):
    p_data_key, p_t_key = jax.random.split(key)
    key, p_data_key, p_t_key = jax.random.split(key, 3)
    if config.brownian_motion:
        key, key_b = jax.random.split(key)
        z0 = sample_p_data(
            key_b, num_samples=32_768, r1=config.r1, r2=config.r2, d=config.d
        )
    samples = ode_trajectory(
        p_t_key,
@@ -127,6 +193,7 @@ def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int
        num_samples=32_768,
        sample_steps=config.sample_steps,
        space_dimensions=config.space_dimensions,
        cond=z0 if config.brownian_motion else None,
    )
    stats_donut = array_bool_statistics(
@@ -140,7 +207,8 @@ def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int
    plt.figure()
    sns.scatterplot(x=samples[:, 0], y=samples[:, 1], size=0.1)
    save_folder = (
        Path("results") / f"mlp_l{config.num_hidden_layers}_h{config.hidden_size}"
        Path("results")
        / f"mlp{'_b' if config.brownian_motion else ''}_l{config.num_hidden_layers}_h{config.hidden_size}"
    )
    save_folder.mkdir(parents=True, exist_ok=True)
    plt.savefig(
@@ -238,7 +306,7 @@ beta_grad = jax.jit(jax.vmap(beta_scalar_grad))
def ode_step(model: MLP, x_t: jax.Array, t: jax.Array, h: float) -> jax.Array:
    return x_t + h * model(x_t, t)
    # x_t may carry the condition in its trailing features; only the first two
    # (spatial) coordinates are advanced by the predicted velocity.
    return x_t[..., :2] + h * model(x_t, t)
@partial(nnx.jit, static_argnums=(2, 3, 4))
@@ -248,6 +316,7 @@ def ode_trajectory(
    num_samples: int,
    sample_steps: int,
    space_dimensions: int,
    cond: Optional[jax.Array] = None,
) -> jax.Array:
    t = jnp.zeros((num_samples,))
    h = 1.0 / sample_steps
@@ -255,15 +324,11 @@ def ode_trajectory(
    def body(i, state):
        t, x = state
        x = ode_step(model, x, t, h)
        # Re-attach the condition at every step: the model sees (x, z0).
        x_cond = jnp.concat((x, cond), axis=-1) if cond is not None else x
        x = ode_step(model, x_cond, t, h)
        return (t + h, x)
    _, x = jax.lax.fori_loop(0, sample_steps, body, (t, x))
    # for i in range(sample_steps):
    #     x = ode_step(model, x, t, h)
    #     t = t + h
    return x
@@ -278,20 +343,67 @@ def sde_trajectory(model: MLP) -> jax.Array:
# --- Training ----------------------------------------
def main(config: Config):
    rngs = nnx.Rngs(config.seed)
@nnx.jit
def train_step(model: MLP, optim: nnx.Optimizer, z: jax.Array, key: jax.random.PRNGKey):
    key_e, key_t = jax.random.split(key, 2)
    eps = jax.random.normal(key=key_e, shape=z.shape)
    t = jax.random.uniform(key=key_t, shape=[z.shape[0]])
    x = alpha(t)[:, None] * z + beta(t)[:, None] * eps
    def loss_fn(model, z, t, eps):
        loss = jnp.sum(
            (model(x, t) - (alpha_grad(t)[:, None] * z + beta_grad(t)[:, None] * eps))
            ** 2,
            axis=-1,
        )
        return jnp.mean(loss)
    value_grad_fn = nnx.value_and_grad(loss_fn)
    loss, grads = value_grad_fn(model, z, t, eps)
    optim.update(grads)
    return loss
    if config.show_p_data:
        points = sample_p_data(rngs.params(), 32_768, config.r1, config.r2, config.d)
        plt.figure()
        sns.scatterplot(x=points[:, 0], y=points[:, 1], size=0.1)
        if not os.path.exists("./results"):
            os.mkdir("results")
        plt.savefig("results/p_data.png", dpi=300, bbox_inches="tight")
        exit()
@nnx.jit
def train_step_brownian(
    model: MLP,
    optim: nnx.Optimizer,
    z: tuple[jax.Array, jax.Array],
    key: jax.random.PRNGKey,
):
    z0, z1 = z
    assert z0.shape == z1.shape
    key_e, key_t = jax.random.split(key)
    eps = jax.random.normal(key=key_e, shape=z1.shape)
    t = jax.random.uniform(key=key_t, shape=[z1.shape[0]])
    # Noise the Brownian endpoint z1; the network is conditioned on the start z0.
    x = alpha(t)[:, None] * z1 + beta(t)[:, None] * eps
    def loss_fn(model):
        x_brownian = jnp.concat((x, z0), axis=-1)
        loss = jnp.sum(
            (
                model(x_brownian, t)
                # Conditional velocity target u_t(x | z1), matching the slides.
                - (alpha_grad(t)[:, None] * z1 + beta_grad(t)[:, None] * eps)
            )
            ** 2,
            axis=-1,
        )
        return jnp.mean(loss)
    value_grad_fn = nnx.value_and_grad(loss_fn)
    loss, grads = value_grad_fn(model)
    optim.update(grads)
    return loss
def setup_model_and_optimizer(config: Config, rngs: nnx.Rngs) -> tuple[MLP, nnx.Optimizer]:
    model = MLP(
        in_features=config.space_dimensions,
        # With Brownian conditioning the model also sees z0, doubling the input width.
        in_features=2 * config.space_dimensions
        if config.brownian_motion
        else config.space_dimensions,
        out_features=config.space_dimensions,
        num_hidden_layers=config.num_hidden_layers,
        hidden_size=config.hidden_size,
@@ -305,51 +417,133 @@ def main(config: Config):
        tx=optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=3e-4)),
    )
    @nnx.jit
    def train_step(
        model: MLP, optim: nnx.Optimizer, z: jax.Array, key: jax.random.PRNGKey
    ):
        key_e, key_t = jax.random.split(key)
        eps = jax.random.normal(key=key_e, shape=z.shape)
        t = jax.random.uniform(key=key_t, shape=[z.shape[0]])
        x = alpha(t)[:, None] * z + beta(t)[:, None] * eps
        def loss_fn(model, z, t, eps):
            loss = jnp.sum(
                (
                    model(x, t)
                    - (alpha_grad(t)[:, None] * z + beta_grad(t)[:, None] * eps)
    return model, optim
def update_experiments_log(i, stats_donut, stats_dent, config: Config) -> None:
    try:
        with FileLock(os.path.join("results", "experiments.lock")):
            file = os.path.join("results", "experiments.csv")
            if os.path.exists(file):
                df = pd.read_csv(file)
            else:
                df = pd.DataFrame(
                    columns=[
                        "num_hidden_layers",
                        "hidden_size",
                        "mlp_bias",
                        "fourier_dim",
                        "r1",
                        "r2",
                        "d",
                        "sample_steps",
                        "p_donut",
                        "p_dent",
                        "step",
                        "brownian",
                    ]
                )
                ** 2,
                axis=-1,
            new_row = pd.DataFrame(
                {
                    "num_hidden_layers": [config.num_hidden_layers],
                    "hidden_size": [config.hidden_size],
                    "mlp_bias": [config.mlp_bias],
                    "fourier_dim": [config.fourier_dim],
                    "r1": [config.r1],
                    "r2": [config.r2],
                    "d": [config.d],
                    "sample_steps": [config.sample_steps],
                    "p_donut": [stats_donut[-1]],
                    "p_dent": [stats_dent[-1]],
                    "step": [i],
                    "brownian": [config.brownian_motion],
                }
            )
            return jnp.mean(loss)
        value_grad_fn = nnx.value_and_grad(loss_fn)
        loss, grads = value_grad_fn(model, z, t, eps)
            df = pd.concat([df, new_row], ignore_index=True)
            df.to_csv(file, index=False)
    except Timeout:
        print("Timeout!!!")
def plot_p_data(key: jax.random.PRNGKey, config: Config):
    if not config.brownian_motion:
        points = sample_p_data(key, config.batch_size, config.r1, config.r2, config.d)
        plt.figure()
        sns.scatterplot(x=points[:, 0], y=points[:, 1], size=0.1)
        if not os.path.exists("./results"):
            os.mkdir("results")
        plt.savefig("results/p_data.png", dpi=300, bbox_inches="tight")
    else:
        x0, x1 = brownian_motion_step(
            key,
            config.batch_size,
            config.r1,
            config.r2,
            config.d,
            config.brownian_sigma,
        )
        optim.update(grads)
        return loss
        # Convert the JAX arrays to NumPy for matplotlib.
        x0_np = np.array(x0)  # shape (batch_size, 2)
        x1_np = np.array(x1)  # shape (batch_size, 2)
    cached_train_step = nnx.cached_partial(train_step, model, optim)
        plt.figure(figsize=(8, 8))
        # draw a tiny line for each sample
        for start, end in zip(x0_np, x1_np):
            plt.plot([start[0], end[0]], [start[1], end[1]], linewidth=0.5, alpha=0.5)
        # scatter the start points
        plt.scatter(x0_np[:, 0], x0_np[:, 1], s=10, label="x₀")
        # scatter the end points
        plt.scatter(x1_np[:, 0], x1_np[:, 1], s=10, label="x₁")
        plt.legend(loc="upper right")
        plt.axis("equal")
        plt.title("Brownian step in the dented annulus")
        plt.xlabel("x")
        plt.ylabel("y")
        plt.tight_layout()
        plt.savefig("results/brownian.png")
def main(config: Config):
    rngs = nnx.Rngs(config.seed)
    if config.show_p_data:
        plot_p_data(rngs.params(), config)
        exit()
    model, optim = setup_model_and_optimizer(config, rngs)
    step_fn = train_step_brownian if config.brownian_motion else train_step
    step_fn = nnx.cached_partial(step_fn, model, optim)
    sampler = partial(
        (
            partial(brownian_motion_step, sigma=config.brownian_sigma)
            if config.brownian_motion
            else sample_p_data
        ),
        num_samples=config.batch_size,
        r1=config.r1,
        r2=config.r2,
        d=config.d,
    )
    a_donut = (jnp.pi - 0.5 * config.d) * (config.r2**2 - config.r1**2)
    a_dent = 0.5 * config.d * (config.r2**2 - config.r1**2)
    stats_donut = []
    stats_dent = []
    for i in tqdm(range(config.num_steps)):
        z = sample_p_data(
            rngs.params(),
            num_samples=config.batch_size,
            r1=config.r1,
            r2=config.r2,
            d=config.d,
        )
        _ = cached_train_step(z, rngs.params())
        z = sampler(rngs.params())
        _ = step_fn(z, rngs.params())
        if (
            config.evaluate_constraints_every != 0
            and i > 0
            and i % config.evaluate_constraints_every == 0
        ):
            stat_donut, stat_dent = evaluate_constraints(
@@ -358,48 +552,9 @@ def main(config: Config):
            stats_donut.append(stat_donut.item() / a_donut)
            stats_dent.append(stat_dent.item() / a_dent)
            try:
                with FileLock(os.path.join("results", "experiments.lock")):
                    file = os.path.join("results", "experiments.csv")
                    if os.path.exists(file):
                        df = pd.read_csv(file)
                    else:
                        df = pd.DataFrame(
                            columns=[
                                "num_hidden_layers",
                                "hidden_size",
                                "mlp_bias",
                                "fourier_dim",
                                "r1",
                                "r2",
                                "d",
                                "sample_steps",
                                "p_donut",
                                "p_dent",
                                "step",
                            ]
                        )
                    new_row = pd.DataFrame(
                        {
                            "num_hidden_layers": [config.num_hidden_layers],
                            "hidden_size": [config.hidden_size],
                            "mlp_bias": [config.mlp_bias],
                            "fourier_dim": [config.fourier_dim],
                            "r1": [config.r1],
                            "r2": [config.r2],
                            "d": [config.d],
                            "sample_steps": [config.sample_steps],
                            "p_donut": [stats_donut[-1]],
                            "p_dent": [stats_dent[-1]],
                            "step": [i],
                        }
                    )
                    df = pd.concat([df, new_row], ignore_index=True)
                    df.to_csv(file, index=False)
            except Timeout:
                print("Timeout!!!")
            update_experiments_log(
                i=i, stats_donut=stats_donut, stats_dent=stats_dent, config=config
            )
    # Plot results
    # 1. Set a “pretty” style
@@ -427,7 +582,8 @@ def main(config: Config):
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    save_folder = (
        Path("results") / f"mlp_l{config.num_hidden_layers}_h{config.hidden_size}"
        Path("results")
        / f"mlp{'_b' if config.brownian_motion else ''}_l{config.num_hidden_layers}_h{config.hidden_size}"
    )
    save_folder.mkdir(parents=True, exist_ok=True)
    plt.savefig(os.path.join(save_folder, "constraint_stats.png"))
@@ -435,5 +591,5 @@ def main(config: Config):
if __name__ == "__main__":
    config = tyro.cli(Config)
    config: Config = tyro.cli(Config)
    main(config)
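With `tyro` generating CLI flags from the `Config` fields (so `brownian_motion` should become `--brownian-motion` by tyro's naming convention; the exact flags are inferred, not shown in the diff), a Brownian training run might look like:

```
python train.py --brownian-motion --brownian-sigma 0.1
```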
