commit 88bdf263d1
5 changed files with 2049 additions and 0 deletions
@@ -0,0 +1,10 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv
@@ -0,0 +1 @@
3.11
@@ -0,0 +1,18 @@
[project]
name = "diffusion-points"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
    "distrax>=0.1.5",
    "einops>=0.8.1",
    "flax>=0.10.6",
    "jax[cuda12]>=0.6.0",
    "numpy>=2.2.5",
    "orbax>=0.1.9",
    "seaborn>=0.13.2",
    "tqdm>=4.67.1",
    "tyro>=0.9.19",
    "wandb>=0.19.10",
]
@@ -0,0 +1,258 @@
import tyro
from functools import partial
from dataclasses import dataclass
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax


@dataclass
class Config:
    """Flow/DDM training of a simple distribution."""

    space_dimensions: int = 2
    """The dimensionality of the distribution's space."""

    num_hidden_layers: int = 4
    """Number of hidden layers in the MLP."""

    hidden_size: int = 64
    """The size of the hidden layers of the MLP."""

    mlp_bias: bool = True
    """Enable the bias on every layer of the MLP."""

    fourier_dim: int = 6
    """Fourier dimensions. Will be concatenated to the input of the MLP."""

    fourier_max_period: float = 10_000.0
    """Maximum period of the Fourier time features."""

    num_steps: int = 100_000
    """How many steps of gradient descent to perform."""

    batch_size: int = 512
    """How many samples per mini-batch."""

    r1: float = 0.3
    """Inner radius of the donut for p_data."""

    r2: float = 0.8
    """Outer radius of the donut for p_data."""

    sample_steps: int = 100
    """The number of steps taken during sampling."""

    seed: int = 42
    """The seed used for randomness."""

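# Example invocation (tyro turns the Config fields above into CLI flags; the
# script name here is a placeholder, not from the original commit):
#   python train.py --num-steps 50000 --hidden-size 128 --r1 0.2 --r2 0.9
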
# --- Data generation process ----------------------------
@partial(jax.jit, static_argnums=(1,))
def sample_p_data(
    key: jax.random.PRNGKey, num_samples: int, r1: float, r2: float
) -> jax.Array:
    key_r, key_t = jax.random.split(key)

    u = jax.random.uniform(key_r, (num_samples,), minval=0.0, maxval=1.0)
    # Radius CDF on the annulus: F(r) = (r^2 - r1^2) / (r2^2 - r1^2); invert it
    # (inverse-transform sampling) so points are uniform over the donut's area:
    r = jnp.sqrt(u * (r2**2 - r1**2) + r1**2)

    # Sample angle uniformly in [0, 2pi]
    theta = jax.random.uniform(key_t, (num_samples,), minval=0.0, maxval=2 * jnp.pi)

    # Convert to cartesian
    x = r * jnp.cos(theta)
    y = r * jnp.sin(theta)

    return jnp.stack((x, y), axis=-1)

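# Quick sanity check (illustrative, not part of the original commit): samples
# should land inside the annulus r1 <= ||p|| <= r2:
#   pts = sample_p_data(jax.random.PRNGKey(0), 1024, 0.3, 0.8)
#   radii = jnp.linalg.norm(pts, axis=-1)
#   assert bool(jnp.all((radii >= 0.3) & (radii <= 0.8)))

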
# --- Model definition -----------------------------------
class MLP(nnx.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        fourier_features: int,
        num_hidden_layers: int,
        hidden_size: int,
        use_bias: bool,
        fourier_max_period: float,
        rngs: nnx.Rngs,
    ) -> None:
        self.fourier_dim = fourier_features
        self.fourier_max_period = fourier_max_period

        # Input layer takes the state concatenated with the Fourier time embedding.
        network = [
            nnx.Linear(
                in_features=in_features + fourier_features,
                out_features=hidden_size,
                use_bias=use_bias,
                rngs=rngs,
            ),
            nnx.silu,
        ]
        for _ in range(num_hidden_layers):
            network.append(
                nnx.Linear(
                    in_features=hidden_size,
                    out_features=hidden_size,
                    use_bias=use_bias,
                    rngs=rngs,
                )
            )
            network.append(nnx.silu)

        network.append(
            nnx.Linear(
                in_features=hidden_size,
                out_features=out_features,
                use_bias=use_bias,
                rngs=rngs,
            )
        )

        self.network = nnx.Sequential(*network)

    def time_embed(
        self, t: jax.Array, embed_dim: int, max_period: float = 10_000.0
    ) -> jax.Array:
        assert embed_dim % 2 == 0, "embed_dim must be even"

        # Promote (batch,) -> (batch, 1) so t broadcasts against the
        # frequencies. (The previous `t.shape[-1] != 1` check broke for a
        # batch of size one.)
        if t.ndim == 1:
            t = t[..., None]

        half_dim = embed_dim // 2
        freqs = jnp.exp(
            -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) / half_dim
        )
        args = t * freqs
        time_features = jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)
        return time_features
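    # The frequencies decay geometrically from 1 down to
    # max_period ** (-(half_dim - 1) / half_dim), approaching 1 / max_period
    # for large embed_dim. Illustrative shape check (an assumption, not in the
    # original commit): time_embed(jnp.zeros((4,)), embed_dim=6) -> (4, 6).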

    def __call__(self, x: jax.Array, t: jax.Array) -> jax.Array:
        t_encoded = self.time_embed(t, self.fourier_dim, self.fourier_max_period)
        x = jnp.concatenate((x, t_encoded), axis=-1)
        return self.network(x)


# --- Diffusion functions -----------------------------
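# The pair (alpha, beta) defines the linear interpolant
#   x_t = alpha(t) * z + beta(t) * eps,  z ~ p_data,  eps ~ N(0, I),
# which runs from pure noise at t = 0 (alpha = 0, beta = 1) to data at t = 1.
# train_step below regresses the model onto the time derivative
#   d x_t / dt = alpha'(t) * z + beta'(t) * eps,
# which is why vmapped gradients of alpha and beta are set up here.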
def alpha(t: jax.Array) -> jax.Array:
    return jnp.clip(t, 0.0, 1.0)


alpha_scalar_grad = jax.grad(alpha)
alpha_grad = jax.jit(jax.vmap(alpha_scalar_grad))


def beta(t: jax.Array) -> jax.Array:
    return 1.0 - jnp.clip(t, 0.0, 1.0)


beta_scalar_grad = jax.grad(beta)
beta_grad = jax.jit(jax.vmap(beta_scalar_grad))


def ode_step(model: MLP, x_t: jax.Array, t: jax.Array, h: float) -> jax.Array:
    # One explicit Euler step of dx/dt = model(x, t).
    return x_t + h * model(x_t, t)


def ode_trajectory(
    key: jax.random.PRNGKey, model: MLP, num_samples: int, config: Config
) -> jax.Array:
    t = jnp.zeros((num_samples,))
    h = 1.0 / config.sample_steps
    x = jax.random.normal(key=key, shape=(num_samples, config.space_dimensions))

    # Integrate from t = 0 (noise) to t = 1 (data) in config.sample_steps steps.
    for _ in range(config.sample_steps):
        x = ode_step(model, x, t, h)
        t = t + h

    return x


def sde_step(model: MLP, x_t: jax.Array, t: jax.Array, sigma_t: jax.Array) -> jax.Array:
    raise NotImplementedError  # TODO: single step of a stochastic (SDE) sampler.


def sde_trajectory(model: MLP) -> jax.Array:
    raise NotImplementedError  # TODO: stochastic sampling loop.
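

# A possible way to fill in the stubs above -- a sketch, not part of the
# original commit. For the linear schedule alpha(t) = t, beta(t) = 1 - t, the
# marginal score can be recovered from the learned velocity v:
#   E[eps | x_t] = x_t - t * v(x_t, t),   score(x_t, t) = -E[eps | x_t] / (1 - t),
# and an Euler-Maruyama discretization of
#   dx = [v + (sigma^2 / 2) * score] dt + sigma dW
# keeps the same marginals as the ODE above. The `_sketch` names and the
# constant noise scale `sigma` are illustrative assumptions.
def sde_step_sketch(
    model: MLP,
    key: jax.random.PRNGKey,
    x_t: jax.Array,
    t: jax.Array,
    h: float,
    sigma: float,
) -> jax.Array:
    v = model(x_t, t)
    # Score from velocity; the floor on (1 - t) guards the division near t = 1.
    score = -(x_t - t[:, None] * v) / jnp.maximum(1.0 - t[:, None], 1e-3)
    noise = jax.random.normal(key, x_t.shape)
    return x_t + h * (v + 0.5 * sigma**2 * score) + sigma * jnp.sqrt(h) * noise


def sde_trajectory_sketch(
    key: jax.random.PRNGKey,
    model: MLP,
    num_samples: int,
    config: Config,
    sigma: float = 0.5,
) -> jax.Array:
    key, sub = jax.random.split(key)
    t = jnp.zeros((num_samples,))
    h = 1.0 / config.sample_steps
    x = jax.random.normal(key=sub, shape=(num_samples, config.space_dimensions))

    for _ in range(config.sample_steps):
        key, sub = jax.random.split(key)
        x = sde_step_sketch(model, sub, x, t, h, sigma)
        t = t + h

    return x

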
# --- Training ----------------------------------------


def main(config: Config):
    rngs = nnx.Rngs(config.seed)

    model = MLP(
        in_features=config.space_dimensions,
        out_features=config.space_dimensions,
        num_hidden_layers=config.num_hidden_layers,
        hidden_size=config.hidden_size,
        use_bias=config.mlp_bias,
        fourier_features=config.fourier_dim,
        fourier_max_period=config.fourier_max_period,
        rngs=rngs,
    )
    optim = nnx.Optimizer(
        model,
        tx=optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=3e-4)),
    )

    @nnx.jit
    def train_step(
        model: MLP, optim: nnx.Optimizer, z: jax.Array, key: jax.random.PRNGKey
    ):
        key_e, key_t = jax.random.split(key)
        eps = jax.random.normal(key=key_e, shape=z.shape)
        t = jax.random.uniform(key=key_t, shape=[z.shape[0]])
        # Noisy interpolant between noise eps (t = 0) and data z (t = 1).
        x = alpha(t)[:, None] * z + beta(t)[:, None] * eps

        def loss_fn(model, z, t, eps):
            # Squared error against the conditional velocity
            # d x_t / dt = alpha'(t) * z + beta'(t) * eps,
            # summed over space and averaged over the batch.
            loss = jnp.sum(
                (
                    model(x, t)
                    - (alpha_grad(t)[:, None] * z + beta_grad(t)[:, None] * eps)
                )
                ** 2,
                axis=-1,
            )
            return jnp.mean(loss)

        value_grad_fn = nnx.value_and_grad(loss_fn)
        loss, grads = value_grad_fn(model, z, t, eps)

        optim.update(grads)
        return loss

    # Cache the graph traversal of (model, optim) so repeated jitted calls are cheap.
    cached_train_step = nnx.cached_partial(train_step, model, optim)

    for _ in tqdm(range(config.num_steps)):
        z = sample_p_data(
            rngs.params(), num_samples=config.batch_size, r1=config.r1, r2=config.r2
        )

        _ = cached_train_step(z, rngs.params())

    # Generate samples
    print("sampling...", end="", flush=True)
    samples = ode_trajectory(
        key=rngs.params(), model=model, num_samples=1024, config=config
    )
    print(" done!")
    sns.scatterplot(x=samples[:, 0], y=samples[:, 1])
    plt.savefig("scatter.png", dpi=300, bbox_inches="tight")


if __name__ == "__main__":
    config = tyro.cli(Config)
    main(config)
File diff suppressed because it is too large