commit 88bdf263d1
5 changed files with 2049 additions and 0 deletions
@@ -0,0 +1,10 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv
@@ -0,0 +1 @@
3.11
@@ -0,0 +1,18 @@
[project]
name = "diffusion-points"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
    "distrax>=0.1.5",
    "einops>=0.8.1",
    "flax>=0.10.6",
    "jax[cuda12]>=0.6.0",
    "numpy>=2.2.5",
    "orbax>=0.1.9",
    "seaborn>=0.13.2",
    "tqdm>=4.67.1",
    "tyro>=0.9.19",
    "wandb>=0.19.10",
]
@@ -0,0 +1,258 @@
import tyro
from functools import partial
from dataclasses import dataclass
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax


@dataclass
class Config:
    """Flow/DDM training of a simple distribution."""

    space_dimensions: int = 2
    """The dimensionality of the distribution's space."""

    num_hidden_layers: int = 4
    """Number of hidden layers in the MLP."""

    hidden_size: int = 64
    """The size of the hidden layers of the MLP."""

    mlp_bias: bool = True
    """Enable the bias on every layer of the MLP."""

    fourier_dim: int = 6
    """Number of Fourier time-embedding features concatenated to the MLP input."""

    fourier_max_period: float = 10_000.0
    """Maximum period of the sinusoidal Fourier time embedding."""

    num_steps: int = 100_000
    """How many steps of gradient descent to perform."""

    batch_size: int = 512
    """How many samples per mini-batch."""

    r1: float = 0.3
    """Inner radius of the donut for p_data."""

    r2: float = 0.8
    """Outer radius of the donut for p_data."""

    sample_steps: int = 100
    """The number of steps taken during sampling."""

    seed: int = 42
    """The seed used for randomness."""


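# The Config above is exposed as a CLI via tyro in the __main__ block at the bottom of
# this file. Illustrative usage only (not part of this commit): every field can be
# overridden with a kebab-case flag, e.g.
#
#     config = tyro.cli(Config, args=["--hidden-size", "128", "--num-steps", "20000"])
#
# is equivalent to passing `--hidden-size 128 --num-steps 20000` on the command line.
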
# --- Data generation process ----------------------------
@partial(jax.jit, static_argnums=(1,))
def sample_p_data(
    key: jax.random.PRNGKey, num_samples: int, r1: float, r2: float
) -> jax.Array:
    key_r, key_t = jax.random.split(key)

    u = jax.random.uniform(key_r, (num_samples,), minval=0.0, maxval=1.0)
    # Radius via inverse-CDF sampling: CDF(r) = (r^2 - r1^2) / (r2^2 - r1^2), so invert:
    r = jnp.sqrt(u * (r2**2 - r1**2) + r1**2)

    # Sample angle uniformly in [0, 2pi]
    theta = jax.random.uniform(key_t, (num_samples,), minval=0.0, maxval=2 * jnp.pi)

    # Convert to cartesian
    x = r * jnp.cos(theta)
    y = r * jnp.sin(theta)

    return jnp.stack((x, y), axis=-1)
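
# Why the inverse CDF takes this form (a note, not part of the commit): a uniform
# distribution over the annulus has radial density p(r) proportional to r on [r1, r2], so
# CDF(r) = (r^2 - r1^2) / (r2^2 - r1^2); setting u = CDF(r) and solving for r gives
# r = sqrt(u * (r2^2 - r1^2) + r1^2), exactly the expression used above. Quick check:
#
#     pts = sample_p_data(jax.random.PRNGKey(0), 4096, 0.3, 0.8)
#     radii = jnp.linalg.norm(pts, axis=-1)
#     assert bool(jnp.all((radii >= 0.3) & (radii <= 0.8)))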


# --- Model definition -----------------------------------
class MLP(nnx.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        fourier_features: int,
        num_hidden_layers: int,
        hidden_size: int,
        use_bias: bool,
        fourier_max_period: float,
        rngs: nnx.Rngs,
    ) -> None:
        self.fourier_dim = fourier_features
        self.fourier_max_period = fourier_max_period

        network = [
            nnx.Linear(
                in_features=in_features + fourier_features,
                out_features=hidden_size,
                use_bias=use_bias,
                rngs=rngs,
            ),
            nnx.silu,
        ]
        for _ in range(num_hidden_layers):
            network.append(
                nnx.Linear(
                    in_features=hidden_size,
                    out_features=hidden_size,
                    use_bias=use_bias,
                    rngs=rngs,
                )
            )
            network.append(nnx.silu)

        network.append(
            nnx.Linear(
                in_features=hidden_size,
                out_features=out_features,
                use_bias=use_bias,
                rngs=rngs,
            )
        )

        self.network = nnx.Sequential(*network)

    def time_embed(
        self, t: jax.Array, embed_dim: int, max_period: float = 10_000.0
    ) -> jax.Array:
        assert embed_dim % 2 == 0, "embed_dim must be even"

        # Promote a flat (batch,) time vector to (batch, 1) so it broadcasts against freqs.
        if t.ndim == 1:
            t = t[..., None]

        half_dim = embed_dim // 2
        freqs = jnp.exp(
            -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) / half_dim
        )
        args = t * freqs
        time_features = jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)
        return time_features

    def __call__(self, x: jax.Array, t: jax.Array) -> jax.Array:
        t_encoded = self.time_embed(t, self.fourier_dim, self.fourier_max_period)
        x = jnp.concatenate((x, t_encoded), axis=-1)
        return self.network(x)
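
# Shape check (illustrative, not part of this commit): time_embed maps a (batch,) time
# vector to (batch, fourier_dim) sinusoidal features, so the first Linear layer sees
# in_features + fourier_dim inputs. For example, with the Config defaults:
#
#     mlp = MLP(in_features=2, out_features=2, fourier_features=6, num_hidden_layers=4,
#               hidden_size=64, use_bias=True, fourier_max_period=10_000.0, rngs=nnx.Rngs(0))
#     out = mlp(jnp.zeros((8, 2)), jnp.linspace(0.0, 1.0, 8))  # out.shape == (8, 2)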


# --- Diffusion functions -----------------------------
def alpha(t: jax.Array) -> jax.Array:
    return jnp.clip(t, 0.0, 1.0)


alpha_scalar_grad = jax.grad(alpha)
alpha_grad = jax.jit(jax.vmap(alpha_scalar_grad))


def beta(t: jax.Array) -> jax.Array:
    return 1.0 - jnp.clip(t, 0.0, 1.0)


beta_scalar_grad = jax.grad(beta)
beta_grad = jax.jit(jax.vmap(beta_scalar_grad))
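
# With these schedules, x_t = alpha(t) * z + beta(t) * eps = t * z + (1 - t) * eps, and on
# (0, 1) alpha_grad(t) = 1 and beta_grad(t) = -1, so the regression target in train_step
# below reduces to z - eps: a straight-line (rectified-flow style) velocity field that the
# ODE sampler integrates from t = 0 (pure noise) towards t = 1 (data).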


def ode_step(model: MLP, x_t: jax.Array, t: jax.Array, h: float) -> jax.Array:
    return x_t + h * model(x_t, t)


def ode_trajectory(
    key: jax.random.PRNGKey, model: MLP, num_samples: int, config: Config
) -> jax.Array:
    t = jnp.zeros((num_samples,))
    h = 1.0 / config.sample_steps
    x = jax.random.normal(key=key, shape=(num_samples, config.space_dimensions))

    for i in range(config.sample_steps):
        x = ode_step(model, x, t, h)
        t = t + h

    return x


def sde_step(model: MLP, x_t: jax.Array, t: jax.Array, sigma_t: jax.Array) -> jax.Array:
    pass


def sde_trajectory(model: MLP) -> jax.Array:
    pass
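
# The two stubs above are left unimplemented in this commit. What follows is an
# illustrative Euler-Maruyama sketch of one possible completion (the *_sketch names are
# not from the commit). It assumes the same interpolant as the training loss,
# x_t = t * z + (1 - t) * eps, with the model output treated as the velocity v(x, t);
# then E[eps | x_t] = x_t - t * v(x_t, t), the score is -E[eps | x_t] / (1 - t), and a
# drift correction of 0.5 * sigma**2 * score preserves the marginals of the ODE flow
# while injecting noise of scale sigma.
def sde_step_sketch(
    key: jax.random.PRNGKey,
    model: MLP,
    x_t: jax.Array,
    t: jax.Array,
    h: float,
    sigma: float,
) -> jax.Array:
    v = model(x_t, t)
    eps_hat = x_t - t[:, None] * v
    score = -eps_hat / (1.0 - t[:, None] + 1e-4)  # small offset avoids blow-up near t = 1
    drift = v + 0.5 * sigma**2 * score
    noise = jax.random.normal(key, x_t.shape)
    return x_t + h * drift + sigma * jnp.sqrt(h) * noise


def sde_trajectory_sketch(
    key: jax.random.PRNGKey,
    model: MLP,
    num_samples: int,
    config: Config,
    sigma: float = 0.5,
) -> jax.Array:
    key, init_key = jax.random.split(key)
    x = jax.random.normal(key=init_key, shape=(num_samples, config.space_dimensions))
    t = jnp.zeros((num_samples,))
    h = 1.0 / config.sample_steps
    for _ in range(config.sample_steps):
        key, step_key = jax.random.split(key)
        x = sde_step_sketch(step_key, model, x, t, h, sigma)
        t = t + h
    return x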


# --- Training ----------------------------------------


def main(config: Config):
    rngs = nnx.Rngs(config.seed)

    model = MLP(
        in_features=config.space_dimensions,
        out_features=config.space_dimensions,
        num_hidden_layers=config.num_hidden_layers,
        hidden_size=config.hidden_size,
        use_bias=config.mlp_bias,
        fourier_features=config.fourier_dim,
        fourier_max_period=config.fourier_max_period,
        rngs=rngs,
    )
    optim = nnx.Optimizer(
        model,
        tx=optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=3e-4)),
    )

    @nnx.jit
    def train_step(
        model: MLP, optim: nnx.Optimizer, z: jax.Array, key: jax.random.PRNGKey
    ):
        key_e, key_t = jax.random.split(key)
        eps = jax.random.normal(key=key_e, shape=z.shape)
        t = jax.random.uniform(key=key_t, shape=[z.shape[0]])
        x = alpha(t)[:, None] * z + beta(t)[:, None] * eps

        def loss_fn(model, z, t, eps):
            loss = jnp.sum(
                (
                    model(x, t)
                    - (alpha_grad(t)[:, None] * z + beta_grad(t)[:, None] * eps)
                )
                ** 2,
                axis=-1,
            )
            return jnp.mean(loss)

        value_grad_fn = nnx.value_and_grad(loss_fn)
        loss, grads = value_grad_fn(model, z, t, eps)

        optim.update(grads)
        return loss

    cached_train_step = nnx.cached_partial(train_step, model, optim)

    for i in tqdm(range(config.num_steps)):
        z = sample_p_data(
            rngs.params(), num_samples=config.batch_size, r1=config.r1, r2=config.r2
        )

        _ = cached_train_step(z, rngs.params())

    # Generate samples

    print("sampling...", end="", flush=True)
    samples = ode_trajectory(
        key=rngs.params(), model=model, num_samples=1024, config=config
    )
    print(" done!")
    # samples = np.array(z)
    sns.scatterplot(x=samples[:, 0], y=samples[:, 1])
    plt.savefig("scatter.png", dpi=300, bbox_inches="tight")


if __name__ == "__main__":
    config = tyro.cli(Config)
    main(config)
File diff suppressed because it is too large