import tyro
from functools import partial
from dataclasses import dataclass
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax


@dataclass
class Config:
    """Flow/DDM training of a simple distribution."""

    space_dimensions: int = 2
    """The dimensionality of the distribution's space."""

    num_hidden_layers: int = 4
    """Number of hidden layers in the MLP."""

    hidden_size: int = 64
    """The size of the hidden layers of the MLP."""

    mlp_bias: bool = True
    """Enable the bias on every layer of the MLP."""

    fourier_dim: int = 6
    """Number of Fourier time-embedding dimensions, concatenated to the input of the MLP."""

    fourier_max_period: float = 10_000.0
    """Maximum period of the Fourier time embedding."""

    num_steps: int = 100_000
    """How many steps of gradient descent to perform."""

    batch_size: int = 512
    """How many samples per mini-batch."""

    r1: float = 0.3
    """Inner radius of the donut for p_data."""

    r2: float = 0.8
    """Outer radius of the donut for p_data."""

    sample_steps: int = 100
    """The number of steps taken during sampling."""

    seed: int = 42
    """The seed used for randomness."""


# --- Data generation process ----------------------------
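# p_data is the uniform distribution on an annulus ("donut") with inner radius
# r1 and outer radius r2, sampled in polar coordinates.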
@partial(jax.jit, static_argnums=(1,))
def sample_p_data(
    key: jax.Array, num_samples: int, r1: float, r2: float
) -> jax.Array:
    key_r, key_t = jax.random.split(key)

    # Sample the radius by inverse-transform sampling: for a uniform density on
    # the annulus, CDF(r) = (r^2 - r1^2) / (r2^2 - r1^2), so invert:
    u = jax.random.uniform(key_r, (num_samples,), minval=0.0, maxval=1.0)
    r = jnp.sqrt(u * (r2**2 - r1**2) + r1**2)

    # Sample the angle uniformly in [0, 2pi]
    theta = jax.random.uniform(key_t, (num_samples,), minval=0.0, maxval=2 * jnp.pi)

    # Convert to cartesian coordinates
    x = r * jnp.cos(theta)
    y = r * jnp.sin(theta)

    return jnp.stack((x, y), axis=-1)


# --- Model definition -----------------------------------
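# The velocity field u_theta(x, t): a plain MLP over the concatenation of x
# and a Fourier (sinusoidal) embedding of the time t.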
class MLP(nnx.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        fourier_features: int,
        num_hidden_layers: int,
        hidden_size: int,
        use_bias: bool,
        fourier_max_period: float,
        rngs: nnx.Rngs,
    ) -> None:
        self.fourier_dim = fourier_features
        self.fourier_max_period = fourier_max_period

        network = [
            nnx.Linear(
                in_features=in_features + fourier_features,
                out_features=hidden_size,
                use_bias=use_bias,
                rngs=rngs,
            ),
            nnx.silu,
        ]
        for _ in range(num_hidden_layers):
            network.append(
                nnx.Linear(
                    in_features=hidden_size,
                    out_features=hidden_size,
                    use_bias=use_bias,
                    rngs=rngs,
                )
            )
            network.append(nnx.silu)

        network.append(
            nnx.Linear(
                in_features=hidden_size,
                out_features=out_features,
                use_bias=use_bias,
                rngs=rngs,
            )
        )

        self.network = nnx.Sequential(*network)

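    # Sinusoidal time embedding (transformer-style positional encoding): half
    # the dimensions carry sin, half cos, with frequencies log-spaced between
    # 1 and 1 / max_period, so the embedding varies smoothly with t.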
    def time_embed(
        self, t: jax.Array, embed_dim: int, max_period: float = 10_000.0
    ) -> jax.Array:
        assert embed_dim % 2 == 0, "embed_dim must be even"

        if t.shape[-1] != 1:
            t = t[..., None]

        half_dim = embed_dim // 2
        freqs = jnp.exp(
            -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) / half_dim
        )
        args = t * freqs
        time_features = jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)
        return time_features

    def __call__(self, x: jax.Array, t: jax.Array) -> jax.Array:
        t_encoded = self.time_embed(t, self.fourier_dim, self.fourier_max_period)
        x = jnp.concatenate((x, t_encoded), axis=-1)
        return self.network(x)


# --- Diffusion functions -----------------------------
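# Linear interpolant: x_t = alpha(t) * z + beta(t) * eps with z ~ p_data and
# eps ~ N(0, I), so t = 0 is pure noise and t = 1 is data. The network is
# trained to regress the conditional velocity alpha'(t) * z + beta'(t) * eps.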
def alpha(t: jax.Array) -> jax.Array:
    return jnp.clip(t, 0.0, 1.0)


alpha_scalar_grad = jax.grad(alpha)
alpha_grad = jax.jit(jax.vmap(alpha_scalar_grad))


def beta(t: jax.Array) -> jax.Array:
    return 1.0 - jnp.clip(t, 0.0, 1.0)


beta_scalar_grad = jax.grad(beta)
beta_grad = jax.jit(jax.vmap(beta_scalar_grad))


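# ODE sampling: Euler integration of dx/dt = u_theta(x, t) from t = 0 (a
# standard Gaussian) to t = 1 (the model's approximation of p_data).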
def ode_step(model: MLP, x_t: jax.Array, t: jax.Array, h: float) -> jax.Array:
    return x_t + h * model(x_t, t)


def ode_trajectory(
    key: jax.Array, model: MLP, num_samples: int, config: Config
) -> jax.Array:
    t = jnp.zeros((num_samples,))
    h = 1.0 / config.sample_steps
    x = jax.random.normal(key=key, shape=(num_samples, config.space_dimensions))

    for _ in range(config.sample_steps):
        x = ode_step(model, x, t, h)
        t = t + h

    return x


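# SDE sampling: a minimal Euler--Maruyama sketch, assuming the linear schedule
# alpha(t) = t, beta(t) = 1 - t above and a scalar noise level sigma_t. For
# this interpolant the marginal score can be recovered from the learned
# velocity u via E[eps | x_t] = x_t - t * u(x_t, t), giving
#     score(x_t, t) = (t * u(x_t, t) - x_t) / (1 - t),
# and dx = [u + (sigma_t^2 / 2) * score] dt + sigma_t dW preserves the same
# marginals as the ODE. A step size and PRNG key are taken as extra arguments,
# since the stochastic step needs both.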
def sde_step(
    model: MLP,
    x_t: jax.Array,
    t: jax.Array,
    sigma_t: jax.Array,
    h: float,
    key: jax.Array,
) -> jax.Array:
    u = model(x_t, t)
    score = (t[:, None] * u - x_t) / (1.0 - t[:, None])
    drift = u + 0.5 * sigma_t**2 * score
    noise = jax.random.normal(key=key, shape=x_t.shape)
    return x_t + h * drift + sigma_t * jnp.sqrt(h) * noise


def sde_trajectory(
    key: jax.Array, model: MLP, num_samples: int, sigma: float, config: Config
) -> jax.Array:
    t = jnp.zeros((num_samples,))
    h = 1.0 / config.sample_steps
    key, key_x = jax.random.split(key)
    x = jax.random.normal(key=key_x, shape=(num_samples, config.space_dimensions))

    # Stop the stochastic steps one step early: the score blows up as t -> 1.
    for _ in range(config.sample_steps - 1):
        key, key_step = jax.random.split(key)
        x = sde_step(model, x, t, jnp.asarray(sigma), h, key_step)
        t = t + h

    # Finish with a deterministic step to land exactly at t = 1.
    return ode_step(model, x, t, h)


# --- Training ----------------------------------------

def main(config: Config):
    rngs = nnx.Rngs(config.seed)

    model = MLP(
        in_features=config.space_dimensions,
        out_features=config.space_dimensions,
        num_hidden_layers=config.num_hidden_layers,
        hidden_size=config.hidden_size,
        use_bias=config.mlp_bias,
        fourier_features=config.fourier_dim,
        fourier_max_period=config.fourier_max_period,
        rngs=rngs,
    )
    optim = nnx.Optimizer(
        model,
        tx=optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=3e-4)),
    )

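    # Conditional flow matching: draw t ~ U[0, 1], form x_t = alpha(t) z + beta(t) eps,
    # and regress the model output at (x_t, t) onto alpha'(t) z + beta'(t) eps.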
    @nnx.jit
    def train_step(
        model: MLP, optim: nnx.Optimizer, z: jax.Array, key: jax.Array
    ):
        key_e, key_t = jax.random.split(key)
        eps = jax.random.normal(key=key_e, shape=z.shape)
        t = jax.random.uniform(key=key_t, shape=(z.shape[0],))
        x = alpha(t)[:, None] * z + beta(t)[:, None] * eps

        def loss_fn(model, x, z, t, eps):
            loss = jnp.sum(
                (
                    model(x, t)
                    - (alpha_grad(t)[:, None] * z + beta_grad(t)[:, None] * eps)
                )
                ** 2,
                axis=-1,
            )
            return jnp.mean(loss)

        value_grad_fn = nnx.value_and_grad(loss_fn)
        loss, grads = value_grad_fn(model, x, z, t, eps)

        optim.update(grads)
        return loss

    cached_train_step = nnx.cached_partial(train_step, model, optim)

    for _ in tqdm(range(config.num_steps)):
        z = sample_p_data(
            rngs.params(), num_samples=config.batch_size, r1=config.r1, r2=config.r2
        )
        _ = cached_train_step(z, rngs.params())

    # Generate samples from the trained model and plot them.
    print("sampling...", end="", flush=True)
    samples = ode_trajectory(
        key=rngs.params(), model=model, num_samples=1024, config=config
    )
    print(" done!")
    sns.scatterplot(x=samples[:, 0], y=samples[:, 1])
    plt.savefig("scatter.png", dpi=300, bbox_inches="tight")


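# Example invocation (hypothetical script name; tyro derives the flags from Config):
#   python flow_matching.py --num-steps 20000 --batch-size 256 --r1 0.4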
if __name__ == "__main__":
    config = tyro.cli(Config)
    main(config)