Browse Source

First commit

master
CALVO GONZALEZ Ramon 7 months ago
commit
88bdf263d1
  1. 10
      .gitignore
  2. 1
      .python-version
  3. 18
      pyproject.toml
  4. 258
      train.py
  5. 1762
      uv.lock

10
.gitignore

@ -0,0 +1,10 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv

1
.python-version

@ -0,0 +1 @@
3.11

18
pyproject.toml

@ -0,0 +1,18 @@
[project]
name = "diffusion-points"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"distrax>=0.1.5",
"einops>=0.8.1",
"flax>=0.10.6",
"jax[cuda12]>=0.6.0",
"numpy>=2.2.5",
"orbax>=0.1.9",
"seaborn>=0.13.2",
"tqdm>=4.67.1",
"tyro>=0.9.19",
"wandb>=0.19.10",
]

258
train.py

@ -0,0 +1,258 @@
import tyro
from functools import partial
from dataclasses import dataclass
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax
@dataclass
class Config:
"""Flow/DDM training of a simple distribution."""
space_dimensions: int = 2
"""The dimensionality of the distribution's space."""
num_hidden_layers: int = 4
"""Number of hidden layers in the MLP."""
hidden_size: int = 64
"""The size of the hidden layers of the MLP."""
mlp_bias: bool = True
"""Enable the bias on every layer of the MLP."""
fourier_dim: int = 6
"""Fourier dimensions. Will be concatenated to the input of the MLP."""
fourier_max_period: float = 10_000.0
"""Range of features of the Fourier features."""
num_steps: int = 100_000
"""How many steps of gradient descent to perform."""
batch_size: int = 512
"""How many samples per mini-batch."""
r1: float = 0.3
"""Inner radius of the donut for p_data"""
r2: float = 0.8
"""Outer radius of the donut for p_data"""
sample_steps: int = 100
"""The number of steps taken during sampling"""
seed: int = 42
"""The seed used for randomness."""
# --- Data generation process ----------------------------
@partial(jax.jit, static_argnums=(1,))
def sample_p_data(
key: jax.random.PRNGKey, num_samples: int, r1: float, r2: float
) -> jax.Array:
key_r, key_t = jax.random.split(key)
u = jax.random.uniform(key_r, (num_samples,), minval=0.0, maxval=1.0)
# radius distribution r => CDF(r) = (r^2 - r1^2)/(r^2 - r1^2) => invert:
r = jnp.sqrt(u * (r2**2 - r1**2) + r1**2)
# Sample angle uniformly in [0, 2pi]
theta = jax.random.uniform(key_t, (num_samples,), minval=0.0, maxval=2 * jnp.pi)
# Convert to cartesian
x = r * jnp.cos(theta)
y = r * jnp.sin(theta)
return jnp.stack((x, y), axis=-1)
# --- Model definition -----------------------------------
class MLP(nnx.Module):
def __init__(
self,
in_features: int,
out_features: int,
fourier_features: int,
num_hidden_layers: int,
hidden_size: int,
use_bias: bool,
fourier_max_period: float,
rngs: nnx.Rngs,
) -> None:
self.fourier_dim = fourier_features
self.fourier_max_period = fourier_max_period
network = [
nnx.Linear(
in_features=in_features + fourier_features,
out_features=hidden_size,
use_bias=use_bias,
rngs=rngs,
),
nnx.silu,
]
for _ in range(num_hidden_layers):
network.append(
nnx.Linear(
in_features=hidden_size,
out_features=hidden_size,
use_bias=use_bias,
rngs=rngs,
)
)
network.append(nnx.silu)
network.append(
nnx.Linear(
in_features=hidden_size,
out_features=out_features,
use_bias=use_bias,
rngs=rngs,
)
)
self.network = nnx.Sequential(*network)
def time_embed(
self, t: jax.Array, embed_dim: int, max_period: float = 10_000.0
) -> jax.Array:
assert embed_dim % 2 == 0, "embed_dim must be even"
if t.shape[-1] != 1:
t = t[..., None]
half_dim = embed_dim // 2
freqs = jnp.exp(
-jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) / half_dim
)
args = t * freqs
time_features = jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)
return time_features
def __call__(self, x: jax.Array, t: jax.Array) -> jax.Array:
t_encoded = self.time_embed(t, self.fourier_dim, self.fourier_max_period)
x = jnp.concatenate((x, t_encoded), axis=-1)
return self.network(x)
# --- Diffusion functions -----------------------------
def alpha(t: jax.Array) -> jax.Array:
return jnp.clip(t, 0.0, 1.0)
alpha_scalar_grad = jax.grad(alpha)
alpha_grad = jax.jit(jax.vmap(alpha_scalar_grad))
def beta(t: jax.Array) -> jax.Array:
return 1.0 - jnp.clip(t, 0.0, 1.0)
beta_scalar_grad = jax.grad(beta)
beta_grad = jax.jit(jax.vmap(beta_scalar_grad))
def ode_step(model: MLP, x_t: jax.Array, t: jax.Array, h: float) -> jax.Array:
return x_t + h * model(x_t, t)
def ode_trajectory(
key: jax.random.PRNGKey, model: MLP, num_samples: int, config: Config
) -> jax.Array:
t = jnp.zeros((num_samples,))
h = 1.0 / config.sample_steps
x = jax.random.normal(key=key, shape=(num_samples, config.space_dimensions))
for i in range(config.sample_steps):
x = ode_step(model, x, t, h)
t = t + h
return x
def sde_step(model: MLP, x_t: jax.Array, t: jax.Array, sigma_t: jax.Array) -> jax.Array:
pass
def sde_trajectory(model: MLP) -> jax.Array:
pass
# --- Training ----------------------------------------
def main(config: Config):
rngs = nnx.Rngs(config.seed)
model = MLP(
in_features=config.space_dimensions,
out_features=config.space_dimensions,
num_hidden_layers=config.num_hidden_layers,
hidden_size=config.hidden_size,
use_bias=config.mlp_bias,
fourier_features=config.fourier_dim,
fourier_max_period=config.fourier_max_period,
rngs=rngs,
)
optim = nnx.Optimizer(
model,
tx=optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=3e-4)),
)
@nnx.jit
def train_step(
model: MLP, optim: nnx.Optimizer, z: jax.Array, key: jax.random.PRNGKey
):
key_e, key_t = jax.random.split(key)
eps = jax.random.normal(key=key_e, shape=z.shape)
t = jax.random.uniform(key=key_t, shape=[z.shape[0]])
x = alpha(t)[:, None] * z + beta(t)[:, None] * eps
def loss_fn(model, z, t, eps):
loss = jnp.sum(
(
model(x, t)
- (alpha_grad(t)[:, None] * z + beta_grad(t)[:, None] * eps)
)
** 2,
axis=-1,
)
return jnp.mean(loss)
value_grad_fn = nnx.value_and_grad(loss_fn)
loss, grads = value_grad_fn(model, z, t, eps)
optim.update(grads)
return loss
cached_train_step = nnx.cached_partial(train_step, model, optim)
for i in tqdm(range(config.num_steps)):
z = sample_p_data(
rngs.params(), num_samples=config.batch_size, r1=config.r1, r2=config.r2
)
_ = cached_train_step(z, rngs.params())
# Generate samples
print("sampling...", end="", flush=True)
samples = ode_trajectory(
key=rngs.params(), model=model, num_samples=1024, config=config
)
print(" done!")
# samples = np.array(z)
sns.scatterplot(x=samples[:, 0], y=samples[:, 1])
plt.savefig("scatter.png", dpi=300, bbox_inches="tight")
if __name__ == "__main__":
config = tyro.cli(Config)
main(config)

1762
uv.lock

File diff suppressed because it is too large
Loading…
Cancel
Save