import tyro
from functools import partial
from dataclasses import dataclass
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax


@dataclass
class Config:
    """Flow/DDM training of a simple distribution."""

    space_dimensions: int = 2
    """The dimensionality of the distribution's space (the donut sampler only supports 2)."""
    num_hidden_layers: int = 4
    """Number of hidden layers in the MLP."""
    hidden_size: int = 64
    """The size of the hidden layers of the MLP."""
    mlp_bias: bool = True
    """Enable the bias on every layer of the MLP."""
    fourier_dim: int = 6
    """Number of sinusoidal time features concatenated to the MLP input (must be even)."""
    fourier_max_period: float = 10_000.0
    """Longest period of the sinusoidal time features; sets their frequency range."""
    num_steps: int = 100_000
    """How many steps of gradient descent to perform."""
    batch_size: int = 512
    """How many samples per mini-batch."""
    r1: float = 0.3
    """Inner radius of the donut for p_data."""
    r2: float = 0.8
    """Outer radius of the donut for p_data."""
    sample_steps: int = 100
    """The number of steps taken during sampling."""
    seed: int = 42
    """The seed used for randomness."""


# --- Data generation process ----------------------------


@partial(jax.jit, static_argnums=(1,))
def sample_p_data(
    key: jax.Array, num_samples: int, r1: float, r2: float
) -> jax.Array:
    """Draw points uniformly (by area) from the 2-D annulus r1 <= |p| <= r2.

    Args:
        key: PRNG key.
        num_samples: Number of points. Static under ``jax.jit`` because it
            fixes the output shape.
        r1: Inner radius. Assumed ``0 <= r1 < r2``; not checked here because
            the radii are traced values.
        r2: Outer radius.

    Returns:
        Array of shape ``(num_samples, 2)`` holding Cartesian coordinates.
    """
    key_r, key_t = jax.random.split(key)
    u = jax.random.uniform(key_r, (num_samples,), minval=0.0, maxval=1.0)
    # Uniform over the annulus *area* means p(r) is proportional to r on
    # [r1, r2], so CDF(r) = (r^2 - r1^2) / (r2^2 - r1^2). Inverting gives:
    r = jnp.sqrt(u * (r2**2 - r1**2) + r1**2)
    # Angle uniform in [0, 2*pi).
    theta = jax.random.uniform(key_t, (num_samples,), minval=0.0, maxval=2 * jnp.pi)
    # Polar -> Cartesian.
    x = r * jnp.cos(theta)
    y = r * jnp.sin(theta)
    return jnp.stack((x, y), axis=-1)


# --- Model definition -----------------------------------


class MLP(nnx.Module):
    """Velocity-field MLP ``v(x, t)`` conditioned on sinusoidal time features.

    ``t`` is embedded into ``fourier_features`` sin/cos features, which are
    concatenated to ``x`` and fed through ``num_hidden_layers + 2`` linear
    layers with a SiLU after every layer except the last.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        fourier_features: int,
        num_hidden_layers: int,
        hidden_size: int,
        use_bias: bool,
        fourier_max_period: float,
        rngs: nnx.Rngs,
    ) -> None:
        # time_embed splits the features evenly between sin and cos, so an
        # odd count cannot work; fail at construction rather than first call.
        if fourier_features % 2 != 0:
            raise ValueError(
                f"fourier_features must be even, got {fourier_features}"
            )
        self.fourier_dim = fourier_features
        self.fourier_max_period = fourier_max_period

        network = [
            nnx.Linear(
                in_features=in_features + fourier_features,
                out_features=hidden_size,
                use_bias=use_bias,
                rngs=rngs,
            ),
            nnx.silu,
        ]
        for _ in range(num_hidden_layers):
            network.append(
                nnx.Linear(
                    in_features=hidden_size,
                    out_features=hidden_size,
                    use_bias=use_bias,
                    rngs=rngs,
                )
            )
            network.append(nnx.silu)
        network.append(
            nnx.Linear(
                in_features=hidden_size,
                out_features=out_features,
                use_bias=use_bias,
                rngs=rngs,
            )
        )
        self.network = nnx.Sequential(*network)

    def time_embed(
        self, t: jax.Array, embed_dim: int, max_period: float = 10_000.0
    ) -> jax.Array:
        """Return sinusoidal features ``[sin(t*f_k), cos(t*f_k)]`` of size ``embed_dim``.

        Frequencies are ``f_k = max_period ** (-k / (embed_dim // 2))`` for
        ``k = 0 .. embed_dim // 2 - 1``.

        ``t`` may be a scalar, shape ``(...,)`` or shape ``(..., 1)``. Note
        that a 1-element ``(1,)`` array is treated as already having the
        trailing singleton axis; ``__call__`` reshapes ``t`` explicitly so it
        never hits that ambiguity.
        """
        assert embed_dim % 2 == 0, "embed_dim must be even"
        t = jnp.asarray(t)
        if t.ndim == 0 or t.shape[-1] != 1:
            t = t[..., None]
        half_dim = embed_dim // 2
        freqs = jnp.exp(
            -jnp.log(max_period) * jnp.arange(half_dim, dtype=jnp.float32) / half_dim
        )
        args = t * freqs
        time_features = jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)
        return time_features

    def __call__(self, x: jax.Array, t: jax.Array) -> jax.Array:
        """Evaluate the network at points ``x`` (last axis = features) and times ``t``.

        ``t`` may be a scalar (shared by all points), shaped like
        ``x.shape[:-1]``, or shaped like ``x.shape[:-1] + (1,)``.
        """
        t = jnp.asarray(t)
        if t.ndim == x.ndim and t.shape[-1] == 1:
            t = t[..., 0]
        # Make t exactly x.shape[:-1] + (1,) so a batch of size 1 is not
        # mistaken for an already-expanded time vector in time_embed.
        t = jnp.broadcast_to(t, x.shape[:-1])[..., None]
        t_encoded = self.time_embed(t, self.fourier_dim, self.fourier_max_period)
        x = jnp.concatenate((x, t_encoded), axis=-1)
        return self.network(x)


# --- Diffusion functions -----------------------------


def alpha(t: jax.Array) -> jax.Array:
    """Data coefficient of the interpolation: alpha(t) = clip(t, 0, 1)."""
    return jnp.clip(t, 0.0, 1.0)


# Elementwise d(alpha)/dt over a 1-D batch of times.
alpha_scalar_grad = jax.grad(alpha)
alpha_grad = jax.jit(jax.vmap(alpha_scalar_grad))


def beta(t: jax.Array) -> jax.Array:
    """Noise coefficient of the interpolation: beta(t) = 1 - clip(t, 0, 1)."""
    return 1.0 - jnp.clip(t, 0.0, 1.0)


# Elementwise d(beta)/dt over a 1-D batch of times.
beta_scalar_grad = jax.grad(beta)
beta_grad = jax.jit(jax.vmap(beta_scalar_grad))


@nnx.jit
def ode_step(model: MLP, x_t: jax.Array, t: jax.Array, h: float) -> jax.Array:
    """One forward-Euler step ``x + h * v(x, t)`` of the learned ODE."""
    return x_t + h * model(x_t, t)


def ode_trajectory(
    key: jax.Array, model: MLP, num_samples: int, config: Config
) -> jax.Array:
    """Sample by integrating dx/dt = model(x, t) from t=0 to t=1 with Euler steps.

    Starts from standard-normal noise (at t=0, alpha=0 and beta=1, so x_0 is
    pure noise) and takes ``config.sample_steps`` steps of size
    ``1 / config.sample_steps``.

    Returns:
        Array of shape ``(num_samples, config.space_dimensions)``.
    """
    h = 1.0 / config.sample_steps
    x = jax.random.normal(key=key, shape=(num_samples, config.space_dimensions))
    for step in range(config.sample_steps):
        # Derive t from the step index instead of accumulating h, so rounding
        # error does not drift t off the grid.
        t = jnp.full((num_samples,), step * h)
        x = ode_step(model, x, t, h)
    return x


def sde_step(model: MLP, x_t: jax.Array, t: jax.Array, sigma_t: jax.Array) -> jax.Array:
    """Placeholder for a stochastic (SDE) sampling step. Not implemented yet."""
    raise NotImplementedError("sde_step is not implemented yet")


def sde_trajectory(model: MLP) -> jax.Array:
    """Placeholder for SDE-based sampling. Not implemented yet."""
    raise NotImplementedError("sde_trajectory is not implemented yet")


# --- Training ----------------------------------------

# How often (in steps) to read the loss back from the device for display;
# reading it forces a host sync, so doing it every step would slow training.
_LOSS_LOG_EVERY = 1_000


def main(config: Config) -> None:
    """Train the velocity MLP on the donut distribution, then sample and plot.

    Writes a scatter plot of 1024 generated samples to ``scatter.png``.

    Raises:
        ValueError: if ``config.space_dimensions`` is not 2 (the data sampler
            only produces 2-D points) or ``config.fourier_dim`` is odd.
    """
    if config.space_dimensions != 2:
        raise ValueError(
            "sample_p_data only generates 2-D data; "
            f"got space_dimensions={config.space_dimensions}"
        )

    rngs = nnx.Rngs(config.seed)
    model = MLP(
        in_features=config.space_dimensions,
        out_features=config.space_dimensions,
        num_hidden_layers=config.num_hidden_layers,
        hidden_size=config.hidden_size,
        use_bias=config.mlp_bias,
        fourier_features=config.fourier_dim,
        fourier_max_period=config.fourier_max_period,
        rngs=rngs,
    )
    # NOTE(review): this is the Optimizer(model, tx) / update(grads) form of
    # the Flax NNX API; newer Flax releases expect `wrt=nnx.Param` and
    # `update(model, grads)` — confirm against the installed version.
    optim = nnx.Optimizer(
        model,
        tx=optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(learning_rate=3e-4)),
    )

    @nnx.jit
    def train_step(
        model: MLP, optim: nnx.Optimizer, z: jax.Array, key: jax.Array
    ) -> jax.Array:
        """One optimizer step on the flow-matching loss for data batch ``z``."""
        key_e, key_t = jax.random.split(key)
        eps = jax.random.normal(key=key_e, shape=z.shape)
        t = jax.random.uniform(key=key_t, shape=(z.shape[0],))
        # Interpolate between noise (t=0) and data (t=1).
        x = alpha(t)[:, None] * z + beta(t)[:, None] * eps
        # Regression target: the time derivative of x_t along the path.
        target = alpha_grad(t)[:, None] * z + beta_grad(t)[:, None] * eps

        def loss_fn(model: MLP) -> jax.Array:
            # Squared error summed over space dims, averaged over the batch.
            per_sample = jnp.sum((model(x, t) - target) ** 2, axis=-1)
            return jnp.mean(per_sample)

        loss, grads = nnx.value_and_grad(loss_fn)(model)
        optim.update(grads)
        return loss

    cached_train_step = nnx.cached_partial(train_step, model, optim)
    progress = tqdm(range(config.num_steps))
    for step in progress:
        z = sample_p_data(
            rngs.params(), num_samples=config.batch_size, r1=config.r1, r2=config.r2
        )
        loss = cached_train_step(z, rngs.params())
        if step % _LOSS_LOG_EVERY == 0:
            progress.set_postfix(loss=f"{float(loss):.4f}")

    # Generate samples.
    print("sampling...", end="", flush=True)
    samples = ode_trajectory(
        key=rngs.params(), model=model, num_samples=1024, config=config
    )
    # JAX dispatches asynchronously; wait so "done!" means sampling finished.
    samples = samples.block_until_ready()
    print(" done!")

    sns.scatterplot(x=samples[:, 0], y=samples[:, 1])
    plt.savefig("scatter.png", dpi=300, bbox_inches="tight")
    plt.close()


if __name__ == "__main__":
    config = tyro.cli(Config)
    main(config)