
feat: brownian motion

master
CALVO GONZALEZ Ramon · 9 months ago
Commit b4652896c0

Files changed:
  1. log.typ (50)
  2. train.py (348)

log.typ (50)
@@ -0,0 +1,50 @@
+#import "@preview/typslides:1.2.5": *
+
+#show: typslides.with(
+  ratio: "16-9",
+  theme: "bluey",
+)
+
+#front-slide(
+  title: "Diffusion and joint laws",
+  subtitle: [And how time might help learn them],
+  authors: ("Ramon",),
+  info: [#link("https://unige.ch")],
+)
+
+// #table-of-contents()
+
+// #title-slide[
+//   Section 1: Introduction
+// ]
+
+// #slide[
+//   == Overview
+//   - Background and motivation
+//   - Research question
+//   - Objectives
+// ]
+
+#slide(title: "Data generation process")[
+  - Visualization of $p_"init"$. We want the diffusion to learn the two constraints (annulus and wedge).
+  #image("./results/p_data.png")
+]
+
+#slide(title: "Norm. flow after 500,000 steps")[
+  - Samples drawn from the flow after 500,000 training steps.
+  #image("./results/mlp_l4_h64/scatter_500000.png")
+]
+
+#slide(title: "Brownian motion p_init")[
+  - Generation process: sample $z_0 ~ p_"init"$, then take one Brownian step: $z_1 ~ cal(N)(z_0, sigma^2 I)$.
+  #image("./results/brownian.png")
+]
+
+#slide(title: "Brownian motion training")[
+  - The new flow matching loss becomes:
+  $
+    cal(L)_"CFM" = EE_(t ~ cal(U)(0, 1), z_0 ~ p_"data", z_1 ~ cal(N)(z_0, sigma^2 I), x ~ p_t (dot | z_1)) [ norm(u_t^theta (x | z_0) - u_t^"target" (x | z_1))^2 ]
+  $
+]
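A note on the regression target used below in train.py: for the Gaussian path $x = alpha(t) z_1 + beta(t) epsilon$ with $epsilon ~ cal(N)(0, I)$, differentiating along $t$ gives the conditional velocity (the standard conditional flow matching identity; the schedules alpha and beta are defined elsewhere in train.py):

$ u_t^"target" (x | z_1) = alpha'(t) z_1 + beta'(t) epsilon $

In code this is `alpha_grad(t)[:, None] * z1 + beta_grad(t)[:, None] * eps`, which is what both `train_step` and `train_step_brownian` regress the model output onto.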

train.py (348)

@@ -10,6 +10,7 @@ import matplotlib.pyplot as plt
 import pandas as pd
 from filelock import FileLock, Timeout
+import numpy as np
 import jax
 import jax.numpy as jnp
 import flax.nnx as nnx
@@ -18,13 +19,19 @@ import optax
 os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.1"


-@dataclass
+@dataclass(frozen=True)
 class Config:
     """Flow/DDM training of a simple distribution."""

     space_dimensions: int = 2
     """The dimensionality of the distribution's space."""

+    brownian_motion: bool = False
+    """Enables Brownian-motion (conditional) training."""
+
+    brownian_sigma: float = 0.1
+    """Standard deviation of the Brownian step."""
+
     num_hidden_layers: int = 4
     """Number of hidden layers in the MLP."""
@@ -111,6 +118,59 @@ def check_dent_constraint(x: jax.Array, d: float) -> jax.Array:
     return jnp.logical_not(jnp.logical_and(theta < d / 2, theta > -d / 2))


+def brownian_motion(key: jax.random.PRNGKey, x0: jax.Array, sigma: float) -> jax.Array:
+    noise = sigma * jax.random.normal(key=key, shape=x0.shape)
+    return x0 + noise
+
+
+@partial(jax.jit, static_argnums=(1,))
+def brownian_motion_step(
+    key: jax.random.PRNGKey,
+    num_samples: int,
+    r1: float,
+    r2: float,
+    d: float = 0.0,
+    sigma: float = 0.1,
+    num_check_pts: int = 8,
+) -> tuple[jax.Array, jax.Array]:
+    key, key_init = jax.random.split(key)
+    x0 = sample_p_data(key_init, num_samples, r1, r2, d)
+
+    def points_satisfy(x: jax.Array) -> jax.Array:
+        return jnp.logical_and(
+            check_donut_constraint(x, r1, r2), check_dent_constraint(x, d)
+        )
+
+    # Interior checkpoints along each proposed segment (endpoints excluded).
+    ts = jnp.linspace(0.0, 1.0, num_check_pts + 2)[1:-1]
+
+    def cond_fn(state):
+        _, _, cs = state
+        return jnp.any(~cs)
+
+    def body_fn(state):
+        key, x, cs = state
+        key, _key = jax.random.split(key)
+        # Re-propose a Brownian step from x0; accepted samples are frozen
+        # below, so only still-rejected samples effectively move.
+        x_prop = brownian_motion(_key, x0, sigma)
+        ok_end = points_satisfy(x_prop)
+        # The straight segment x -> x_prop must also stay inside the
+        # constraint set, checked at `num_check_pts` interior points.
+        diff = x_prop - x
+        pts = x[:, None, :] + diff[:, None, :] * ts[None, :, None]
+        ok_interior = jnp.all(
+            points_satisfy(pts.reshape(-1, pts.shape[-1])).reshape(num_samples, -1),
+            axis=1,
+        )
+        satisfied = ok_end & ok_interior
+        x = jnp.where(jnp.logical_and(~cs, satisfied)[:, None], x_prop, x)
+        cs = jnp.logical_or(cs, satisfied)
+        return key, x, cs
+
+    init_cs = jnp.zeros((num_samples,), dtype=bool)
+    init_state = (key, x0, init_cs)
+    _, x1, _ = jax.lax.while_loop(cond_fn, body_fn, init_state)
+    return x0, x1
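A minimal smoke test for the rejection-sampled Brownian step (hypothetical values; the helpers are the ones defined above in train.py, and shapes assume the 2-D setup):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x0, x1 = brownian_motion_step(key, 1024, 0.5, 1.0, 0.3, 0.1)
assert x0.shape == x1.shape == (1024, 2)
# Both endpoints satisfy the annulus and wedge constraints by construction,
# since rejected proposals are re-drawn inside the while_loop.
ok = jnp.logical_and(
    check_donut_constraint(x1, 0.5, 1.0), check_dent_constraint(x1, 0.3)
)
print(float(ok.mean()))  # expected: 1.0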
 def array_bool_statistics(x: jax.Array) -> float:
     """Computes the % of True in a bool array."""
     assert x.dtype == jnp.bool_
@@ -119,7 +179,13 @@ def array_bool_statistics(x: jax.Array) -> float:

 def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int] = 0):
-    p_data_key, p_t_key = jax.random.split(key)
+    key, p_data_key, p_t_key = jax.random.split(key, 3)
+    if config.brownian_motion:
+        key, key_b = jax.random.split(key)
+        # Fresh conditioning points for the evaluation rollout.
+        z0 = sample_p_data(
+            key_b, num_samples=32_768, r1=config.r1, r2=config.r2, d=config.d
+        )
     samples = ode_trajectory(
         p_t_key,
@@ -127,6 +193,7 @@ def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int
         num_samples=32_768,
         sample_steps=config.sample_steps,
         space_dimensions=config.space_dimensions,
+        cond=z0 if config.brownian_motion else None,
     )

     stats_donut = array_bool_statistics(
@@ -140,7 +207,8 @@ def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int
     plt.figure()
     sns.scatterplot(x=samples[:, 0], y=samples[:, 1], size=0.1)
     save_folder = (
-        Path("results") / f"mlp_l{config.num_hidden_layers}_h{config.hidden_size}"
+        Path("results")
+        / f"mlp{'_b' if config.brownian_motion else ''}_l{config.num_hidden_layers}_h{config.hidden_size}"
     )
     save_folder.mkdir(parents=True, exist_ok=True)
     plt.savefig(
@@ -238,7 +306,7 @@ beta_grad = jax.jit(jax.vmap(beta_scalar_grad))
 def ode_step(model: MLP, x_t: jax.Array, t: jax.Array, h: float) -> jax.Array:
-    return x_t + h * model(x_t, t)
+    # x_t may carry the conditioning z0 in its trailing features; only the
+    # sample's own coordinates are integrated.
+    return x_t[..., :2] + h * model(x_t, t)


 @partial(nnx.jit, static_argnums=(2, 3, 4))
@@ -248,6 +316,7 @@ def ode_trajectory(
     num_samples: int,
     sample_steps: int,
     space_dimensions: int,
+    cond: Optional[jax.Array] = None,
 ) -> jax.Array:
     t = jnp.zeros((num_samples,))
     h = 1.0 / sample_steps
@@ -255,15 +324,11 @@ def ode_trajectory(
     def body(i, state):
         t, x = state
-        x = ode_step(model, x, t, h)
+        x_cond = jnp.concat((x, cond), axis=-1) if cond is not None else x
+        x = ode_step(model, x_cond, t, h)
         return (t + h, x)

     _, x = jax.lax.fori_loop(0, config.sample_steps, body, (t, x))
-    # for i in range(sample_steps):
-    #     x = ode_step(model, x, t, h)
-    #     t = t + h
     return x
@@ -278,20 +343,67 @@ def sde_trajectory(model: MLP) -> jax.Array:
 # --- Training ----------------------------------------


-def main(config: Config):
-    rngs = nnx.Rngs(config.seed)
-
-    if config.show_p_data:
-        points = sample_p_data(rngs.params(), 32_768, config.r1, config.r2, config.d)
-        plt.figure()
-        sns.scatterplot(x=points[:, 0], y=points[:, 1], size=0.1)
-        if not os.path.exists("./results"):
-            os.mkdir("results")
-        plt.savefig("results/p_data.png", dpi=300, bbox_inches="tight")
-        exit()
+@nnx.jit
+def train_step(model: MLP, optim: nnx.Optimizer, z: jax.Array, key: jax.random.PRNGKey):
+    key_e, key_t = jax.random.split(key, 2)
+    eps = jax.random.normal(key=key_e, shape=z.shape)
+    t = jax.random.uniform(key=key_t, shape=[z.shape[0]])
+    x = alpha(t)[:, None] * z + beta(t)[:, None] * eps
+
+    def loss_fn(model, z, t, eps):
+        loss = jnp.sum(
+            (model(x, t) - (alpha_grad(t)[:, None] * z + beta_grad(t)[:, None] * eps))
+            ** 2,
+            axis=-1,
+        )
+        return jnp.mean(loss)
+
+    value_grad_fn = nnx.value_and_grad(loss_fn)
+    loss, grads = value_grad_fn(model, z, t, eps)
+    optim.update(grads)
+    return loss
+
+
+@nnx.jit
+def train_step_brownian(
+    model: MLP,
+    optim: nnx.Optimizer,
+    z: tuple[jax.Array, jax.Array],
+    key: jax.random.PRNGKey,
+):
+    z0, z1 = z
+    assert z0.shape == z1.shape
+    key_e, key_t = jax.random.split(key)
+    eps = jax.random.normal(key=key_e, shape=z1.shape)
+    t = jax.random.uniform(key=key_t, shape=[z1.shape[0]])
+    x = alpha(t)[:, None] * z1 + beta(t)[:, None] * eps
+
+    def loss_fn(model):
+        # Condition on z0 by feature concatenation; the regression target is
+        # the usual CFM velocity alpha'(t) z1 + beta'(t) eps.
+        x_brownian = jnp.concat((x, z0), axis=-1)
+        loss = jnp.sum(
+            (
+                model(x_brownian, t)
+                - (alpha_grad(t)[:, None] * z1 + beta_grad(t)[:, None] * eps)
+            )
+            ** 2,
+            axis=-1,
+        )
+        return jnp.mean(loss)
+
+    value_grad_fn = nnx.value_and_grad(loss_fn)
+    loss, grads = value_grad_fn(model)
+    optim.update(grads)
+    return loss
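As a quick shape check for the conditional step (a sketch, assuming space_dimensions = 2 and a batch of 8):

import jax.numpy as jnp

z0 = jnp.zeros((8, 2))  # condition
x = jnp.zeros((8, 2))   # noisy sample at time t
x_brownian = jnp.concat((x, z0), axis=-1)
assert x_brownian.shape == (8, 4)  # matches in_features = 2 * space_dimensions below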
+
+
+def setup_model_and_optimizer(config: Config, rngs: nnx.Rngs) -> tuple[MLP, nnx.Optimizer]:
     model = MLP(
-        in_features=config.space_dimensions,
+        in_features=2 * config.space_dimensions
+        if config.brownian_motion
+        else config.space_dimensions,
         out_features=config.space_dimensions,
         num_hidden_layers=config.num_hidden_layers,
         hidden_size=config.hidden_size,
@@ -305,51 +417,133 @@ def main(config: Config):
         tx=optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=3e-4)),
     )
-    @nnx.jit
-    def train_step(
-        model: MLP, optim: nnx.Optimizer, z: jax.Array, key: jax.random.PRNGKey
-    ):
-        key_e, key_t = jax.random.split(key)
-        eps = jax.random.normal(key=key_e, shape=z.shape)
-        t = jax.random.uniform(key=key_t, shape=[z.shape[0]])
-        x = alpha(t)[:, None] * z + beta(t)[:, None] * eps
-
-        def loss_fn(model, z, t, eps):
-            loss = jnp.sum(
-                (
-                    model(x, t)
-                    - (alpha_grad(t)[:, None] * z + beta_grad(t)[:, None] * eps)
-                )
-                ** 2,
-                axis=-1,
-            )
-            return jnp.mean(loss)
-
-        value_grad_fn = nnx.value_and_grad(loss_fn)
-        loss, grads = value_grad_fn(model, z, t, eps)
-        optim.update(grads)
-        return loss
-
-    cached_train_step = nnx.cached_partial(train_step, model, optim)
+    return model, optim
+
+
+def update_experiments_log(i, stats_donut, stats_dent, config: Config) -> None:
+    try:
+        with FileLock(os.path.join("results", "experiments.lock")):
+            file = os.path.join("results", "experiments.csv")
+            if os.path.exists(file):
+                df = pd.read_csv(file)
+            else:
+                df = pd.DataFrame(
+                    columns=[
+                        "num_hidden_layers",
+                        "hidden_size",
+                        "mlp_bias",
+                        "fourier_dim",
+                        "r1",
+                        "r2",
+                        "d",
+                        "sample_steps",
+                        "p_donut",
+                        "p_dent",
+                        "step",
+                        "brownian",
+                    ]
+                )
+            new_row = pd.DataFrame(
+                {
+                    "num_hidden_layers": [config.num_hidden_layers],
+                    "hidden_size": [config.hidden_size],
+                    "mlp_bias": [config.mlp_bias],
+                    "fourier_dim": [config.fourier_dim],
+                    "r1": [config.r1],
+                    "r2": [config.r2],
+                    "d": [config.d],
+                    "sample_steps": [config.sample_steps],
+                    "p_donut": [stats_donut[-1]],
+                    "p_dent": [stats_dent[-1]],
+                    "step": [i],
+                    "brownian": [config.brownian_motion],
+                }
+            )
+            df = pd.concat([df, new_row], ignore_index=True)
+            df.to_csv(file, index=False)
+    except Timeout:
+        print("Timeout!!!")
+
+
+def plot_p_data(key: jax.random.PRNGKey, config: Config):
+    if not config.brownian_motion:
+        points = sample_p_data(key, config.batch_size, config.r1, config.r2, config.d)
+        plt.figure()
+        sns.scatterplot(x=points[:, 0], y=points[:, 1], size=0.1)
+        if not os.path.exists("./results"):
+            os.mkdir("results")
+        plt.savefig("results/p_data.png", dpi=300, bbox_inches="tight")
+    else:
+        x0, x1 = brownian_motion_step(
+            key,
+            config.batch_size,
+            config.r1,
+            config.r2,
+            config.d,
+            config.brownian_sigma,
+        )
+        # Convert the JAX arrays to NumPy for matplotlib.
+        x0_np = np.array(x0)  # shape (batch_size, 2)
+        x1_np = np.array(x1)  # shape (batch_size, 2)
+        plt.figure(figsize=(8, 8))
+        # Draw a short segment for each (x0, x1) pair.
+        for start, end in zip(x0_np, x1_np):
+            plt.plot([start[0], end[0]], [start[1], end[1]], linewidth=0.5, alpha=0.5)
+        # Scatter the start and end points.
+        plt.scatter(x0_np[:, 0], x0_np[:, 1], s=10, label="x₀")
+        plt.scatter(x1_np[:, 0], x1_np[:, 1], s=10, label="x₁")
+        plt.legend(loc="upper right")
+        plt.axis("equal")
+        plt.title("Brownian step in the dented annulus")
+        plt.xlabel("x")
+        plt.ylabel("y")
+        plt.tight_layout()
+        plt.savefig("results/brownian.png")
+
+
+def main(config: Config):
+    rngs = nnx.Rngs(config.seed)
+
+    if config.show_p_data:
+        plot_p_data(rngs.params(), config)
+        exit()
+
+    model, optim = setup_model_and_optimizer(config, rngs)
+    step_fn = train_step_brownian if config.brownian_motion else train_step
+    step_fn = nnx.cached_partial(step_fn, model, optim)
+    sampler = partial(
+        (
+            partial(brownian_motion_step, sigma=config.brownian_sigma)
+            if config.brownian_motion
+            else sample_p_data
+        ),
+        num_samples=config.batch_size,
+        r1=config.r1,
+        r2=config.r2,
+        d=config.d,
+    )
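Depending on the flag, `sampler(key)` now returns either a plain batch or a pair, matching the signature of the selected `step_fn`. The loop's contract, sketched (shapes assume space_dimensions = 2):

# brownian_motion=True:  z is (z0, z1), each of shape (batch_size, 2)
# brownian_motion=False: z is a single array of shape (batch_size, 2)
z = sampler(rngs.params())
loss = step_fn(z, rngs.params())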
     a_donut = (jnp.pi - 0.5 * config.d) * (config.r2**2 - config.r1**2)
     a_dent = 0.5 * config.d * (config.r2**2 - config.r1**2)
     stats_donut = []
     stats_dent = []

     for i in tqdm(range(config.num_steps)):
-        z = sample_p_data(
-            rngs.params(),
-            num_samples=config.batch_size,
-            r1=config.r1,
-            r2=config.r2,
-            d=config.d,
-        )
-        _ = cached_train_step(z, rngs.params())
+        z = sampler(rngs.params())
+        _ = step_fn(z, rngs.params())
         if (
             config.evaluate_constraints_every != 0
+            and i > 0
             and i % config.evaluate_constraints_every == 0
         ):
             stat_donut, stat_dent = evaluate_constraints(
@@ -358,48 +552,9 @@ def main(config: Config):
             stats_donut.append(stat_donut.item() / a_donut)
             stats_dent.append(stat_dent.item() / a_dent)
-            try:
-                with FileLock(os.path.join("results", "experiments.lock")):
-                    file = os.path.join("results", "experiments.csv")
-                    if os.path.exists(file):
-                        df = pd.read_csv(file)
-                    else:
-                        df = pd.DataFrame(
-                            columns=[
-                                "num_hidden_layers",
-                                "hidden_size",
-                                "mlp_bias",
-                                "fourier_dim",
-                                "r1",
-                                "r2",
-                                "d",
-                                "sample_steps",
-                                "p_donut",
-                                "p_dent",
-                                "step",
-                            ]
-                        )
-
-                    new_row = pd.DataFrame(
-                        {
-                            "num_hidden_layers": [config.num_hidden_layers],
-                            "hidden_size": [config.hidden_size],
-                            "mlp_bias": [config.mlp_bias],
-                            "fourier_dim": [config.fourier_dim],
-                            "r1": [config.r1],
-                            "r2": [config.r2],
-                            "d": [config.d],
-                            "sample_steps": [config.sample_steps],
-                            "p_donut": [stats_donut[-1]],
-                            "p_dent": [stats_dent[-1]],
-                            "step": [i],
-                        }
-                    )
-                    df = pd.concat([df, new_row], ignore_index=True)
-                    df.to_csv(file, index=False)
-            except Timeout:
-                print("Timeout!!!")
+            update_experiments_log(
+                i=i, stats_donut=stats_donut, stats_dent=stats_dent, config=config
+            )

     # Plot results
     # 1. Set a "pretty" style
@@ -427,7 +582,8 @@ def main(config: Config):
     fig.tight_layout(rect=[0, 0.03, 1, 0.95])
     save_folder = (
-        Path("results") / f"mlp_l{config.num_hidden_layers}_h{config.hidden_size}"
+        Path("results")
+        / f"mlp{'_b' if config.brownian_motion else ''}_l{config.num_hidden_layers}_h{config.hidden_size}"
     )
     save_folder.mkdir(parents=True, exist_ok=True)
     plt.savefig(os.path.join(save_folder, "constraint_stats.png"))
@@ -435,5 +591,5 @@ def main(config: Config):
 if __name__ == "__main__":
-    config = tyro.cli(Config)
+    config: Config = tyro.cli(Config)
     main(config)
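With tyro, the new Config fields surface directly as command-line flags; a plausible invocation (exact boolean-flag spelling depends on tyro's conventions):

python train.py --brownian-motion --brownian-sigma 0.1 --num-hidden-layers 4 --hidden-size 64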
