Browse Source

feat: experiments running

master
CALVO GONZALEZ Ramon 7 months ago
parent
commit
30c0bd30c0
  1. 92
      pd_plots.py
  2. 2
      pyproject.toml
  3. 131
      train.py
  4. 13
      uv.lock

92
pd_plots.py

@ -0,0 +1,92 @@
#!/usr/bin/env python3
"""
pd_plots.py: Load the aggregated experiments.csv and plot constraint satisfaction trends
for all experiments (varying number of hidden layers).
"""
import argparse
from pathlib import Path
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
def parse_args():
parser = argparse.ArgumentParser(
description="Plot constraint satisfaction across experiments"
)
parser.add_argument(
"--results_dir",
type=Path,
default=Path("results"),
help="Directory containing experiments.csv",
)
parser.add_argument(
"--output_dir",
type=Path,
default=Path("results"),
help="Directory to save the combined plot",
)
return parser.parse_args()
def _plot_constraint(df, ax, column, title):
    """Draw ``column`` against ``step`` on ``ax``, one hue per ``num_hidden_layers``."""
    sns.lineplot(
        data=df,
        x="step",
        y=column,
        hue="num_hidden_layers",
        palette="tab10",
        ax=ax,
    )
    ax.set_title(title)
    ax.set_xlabel("Step")
    ax.set_ylabel("% Constraint satisfied")
    ax.legend(title="Hidden Layers")


def main():
    """Load experiments.csv and save a two-panel constraint-satisfaction plot.

    Reads ``<results_dir>/experiments.csv``, draws ``p_donut`` and ``p_dent``
    against ``step`` side by side (shared y-axis, coloured by
    ``num_hidden_layers``) and writes
    ``<output_dir>/all_experiments_constraint_stats.png``.

    Raises:
        FileNotFoundError: If experiments.csv is not present in results_dir.
        ValueError: If the CSV lacks one of the columns the plots read.
    """
    args = parse_args()
    csv_path = args.results_dir / "experiments.csv"
    if not csv_path.exists():
        raise FileNotFoundError(f"Could not find experiments.csv at {csv_path}")

    # Load data
    df = pd.read_csv(csv_path)

    # Fail early with the full list of absent columns instead of letting the
    # first plotting call complain about a single one.
    required_columns = ("step", "num_hidden_layers", "p_donut", "p_dent")
    missing = [name for name in required_columns if name not in df.columns]
    if missing:
        raise ValueError(
            f"{csv_path} is missing required column(s): {', '.join(missing)}"
        )

    # Set style
    sns.set_theme(style="whitegrid", palette="pastel")

    # Create figure with two subplots
    fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
    try:
        _plot_constraint(df, axes[0], "p_donut", "Donut Constraint Satisfaction")
        _plot_constraint(df, axes[1], "p_dent", "Dent Constraint Satisfaction")

        # Overall title and layout
        fig.suptitle(
            "Constraint Satisfaction vs. Steps Across Hidden Layers",
            fontsize=16,
            fontweight="bold",
        )
        # Leave headroom at the top of the layout rectangle for the suptitle.
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])

        # Save figure
        args.output_dir.mkdir(parents=True, exist_ok=True)
        out_path = args.output_dir / "all_experiments_constraint_stats.png"
        fig.savefig(out_path, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving failed, so repeated
        # calls from another module do not accumulate open figures.
        plt.close(fig)
    print(f"Saved combined plot to {out_path}")


# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()

2
pyproject.toml

@ -7,10 +7,12 @@ requires-python = ">=3.11"
dependencies = [
"distrax>=0.1.5",
"einops>=0.8.1",
"filelock>=3.18.0",
"flax>=0.10.6",
"jax[cuda12]>=0.6.0",
"numpy>=2.2.5",
"orbax>=0.1.9",
"pandas>=2.2.3",
"seaborn>=0.13.2",
"tqdm>=4.67.1",
"tyro>=0.9.19",

131
train.py

@ -1,16 +1,22 @@
from typing import Optional
import os
from pathlib import Path
import tyro
from functools import partial
from dataclasses import dataclass
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from filelock import FileLock, Timeout
import jax
import jax.numpy as jnp
import flax.nnx as nnx
import optax
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.1"
@dataclass
class Config:
@ -37,7 +43,7 @@ class Config:
num_steps: int = 100_000
"""How many steps of gradient descent to perform."""
batch_size: int = 512
batch_size: int = 256
"""How many samples per mini-batch."""
r1: float = 0.3
@ -55,6 +61,12 @@ class Config:
evaluate_constraints_every: int = 0
"""Evaluate the constraints after given steps."""
save_scatterplot: bool = False
"""For every step the constraints are evaluated, save a scatterplot."""
show_p_data: bool = False
"""If set, the script will generate a scatterplot sampled from the p_data process."""
seed: int = 42
"""The seed used for randomness."""
@ -109,16 +121,32 @@ def array_bool_statistics(x: jax.Array) -> float:
# Sample points from the model with ode_trajectory and return the pair
# (stats_donut, stats_dent) computed by array_bool_statistics over the donut
# and dent constraint checks. When save_plot is not None, a scatterplot of the
# first two sample coordinates is also written to disk.
# `config` is not a parameter of this function; it is read from the enclosing scope.
# NOTE(review): this span reads as a rendered diff with pre- and post-change
# lines interleaved and indentation stripped (two consecutive `samples = ...`
# calls, and both a plain and a `jnp.logical_not(...)` form of each constraint
# check) — confirm which lines are current before relying on it.
# NOTE(review): `save_plot` defaults to 0 and the guard below is `is not None`,
# so omitting the argument still takes the plotting branch; pass None to skip it.
def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int] = 0):
# p_data_key is not used anywhere in the lines visible here.
p_data_key, p_t_key = jax.random.split(key)
samples = ode_trajectory(p_t_key, model=model, num_samples=1024, config=config)
samples = ode_trajectory(
p_t_key,
model=model,
num_samples=32_768,
sample_steps=config.sample_steps,
space_dimensions=config.space_dimensions,
)
stats_donut = array_bool_statistics(
check_donut_constraint(samples, r1=config.r1, r2=config.r2)
jnp.logical_not(check_donut_constraint(samples, r1=config.r1, r2=config.r2))
)
stats_dent = array_bool_statistics(
jnp.logical_not(check_dent_constraint(samples, d=config.d))
)
stats_dent = array_bool_statistics(check_dent_constraint(samples, d=config.d))
if save_plot is not None:
sns.scatterplot(x=samples[:, 0], y=samples[:, 1])
plt.savefig(f"scatter_{save_plot}.png", dpi=300, bbox_inches="tight")
# Open a fresh figure so this scatterplot does not draw onto an existing one.
plt.figure()
# NOTE(review): seaborn documents `size=` as a semantic grouping variable; a
# fixed marker size is normally passed as `s=` — confirm 0.1 has the intended effect.
sns.scatterplot(x=samples[:, 0], y=samples[:, 1], size=0.1)
# Per-architecture output folder: results/mlp_l<num_hidden_layers>_h<hidden_size>.
save_folder = (
Path("results") / f"mlp_l{config.num_hidden_layers}_h{config.hidden_size}"
)
save_folder.mkdir(parents=True, exist_ok=True)
plt.savefig(
f"{save_folder}/scatter_{save_plot}.png", dpi=300, bbox_inches="tight"
)
plt.close()
return stats_donut, stats_dent
@ -213,16 +241,28 @@ def ode_step(model: MLP, x_t: jax.Array, t: jax.Array, h: float) -> jax.Array:
return x_t + h * model(x_t, t)
# Start from normally distributed samples x of shape
# (num_samples, space_dimensions) with t = 0 for every sample, repeatedly apply
# ode_step with step size h = 1 / sample_steps, and return the final x.
# With the five-parameter signature, static_argnums=(2, 3, 4) selects
# num_samples, sample_steps and space_dimensions.
# NOTE(review): this span reads as a rendered diff with pre- and post-change
# lines interleaved and indentation stripped (two signatures, two `h = ...` and
# two `x = ...` assignments, a Python `for` header next to the fori_loop body)
# — confirm which lines are current before relying on it.
@partial(nnx.jit, static_argnums=(2, 3, 4))
def ode_trajectory(
key: jax.random.PRNGKey, model: MLP, num_samples: int, config: Config
key: jax.random.PRNGKey,
model: MLP,
num_samples: int,
sample_steps: int,
space_dimensions: int,
) -> jax.Array:
t = jnp.zeros((num_samples,))
h = 1.0 / config.sample_steps
x = jax.random.normal(key=key, shape=(num_samples, config.space_dimensions))
h = 1.0 / sample_steps
x = jax.random.normal(key=key, shape=(num_samples, space_dimensions))
for i in range(config.sample_steps):
# Loop body for jax.lax.fori_loop; the carried state is the pair (t, x).
def body(i, state):
t, x = state
x = ode_step(model, x, t, h)
# NOTE(review): if this assignment and the `return (t + h, x)` below are both
# live, t advances by 2 * h per iteration — presumably this line belongs to the
# removed Python loop; verify.
t = t + h
return (t + h, x)
# NOTE(review): the loop bound reads `config.sample_steps`, not the
# `sample_steps` parameter used for h above; `config` is not a parameter of the
# five-argument signature, so this depends on a name from the enclosing scope —
# confirm the two cannot disagree.
_, x = jax.lax.fori_loop(0, config.sample_steps, body, (t, x))
# for i in range(sample_steps):
# x = ode_step(model, x, t, h)
# t = t + h
return x
@ -241,6 +281,15 @@ def sde_trajectory(model: MLP) -> jax.Array:
def main(config: Config):
rngs = nnx.Rngs(config.seed)
if config.show_p_data:
points = sample_p_data(rngs.params(), 32_768, config.r1, config.r2, config.d)
plt.figure()
sns.scatterplot(x=points[:, 0], y=points[:, 1], size=0.1)
if not os.path.exists("./results"):
os.mkdir("results")
plt.savefig("results/p_data.png", dpi=300, bbox_inches="tight")
exit()
model = MLP(
in_features=config.space_dimensions,
out_features=config.space_dimensions,
@ -284,6 +333,8 @@ def main(config: Config):
cached_train_step = nnx.cached_partial(train_step, model, optim)
a_donut = (jnp.pi - 0.5 * config.d) * (config.r2**2 - config.r1**2)
a_dent = 0.5 * config.d * (config.r2**2 - config.r1**2)
stats_donut = []
stats_dent = []
for i in tqdm(range(config.num_steps)):
@ -301,9 +352,54 @@ def main(config: Config):
config.evaluate_constraints_every != 0
and i % config.evaluate_constraints_every == 0
):
stat_donut, stat_dent = evaluate_constraints(rngs.params(), model)
stats_donut.append(stat_donut.item())
stats_dent.append(stat_dent.item())
stat_donut, stat_dent = evaluate_constraints(
rngs.params(), model, save_plot=i if config.save_scatterplot else None
)
stats_donut.append(stat_donut.item() / a_donut)
stats_dent.append(stat_dent.item() / a_dent)
try:
with FileLock(os.path.join("results", "experiments.lock")):
file = os.path.join("results", "experiments.csv")
if os.path.exists(file):
df = pd.read_csv(file)
else:
df = pd.DataFrame(
columns=[
"num_hidden_layers",
"hidden_size",
"mlp_bias",
"fourier_dim",
"r1",
"r2",
"d",
"sample_steps",
"p_donut",
"p_dent",
"step",
]
)
new_row = pd.DataFrame(
{
"num_hidden_layers": [config.num_hidden_layers],
"hidden_size": [config.hidden_size],
"mlp_bias": [config.mlp_bias],
"fourier_dim": [config.fourier_dim],
"r1": [config.r1],
"r2": [config.r2],
"d": [config.d],
"sample_steps": [config.sample_steps],
"p_donut": [stats_donut[-1]],
"p_dent": [stats_dent[-1]],
"step": [i],
}
)
df = pd.concat([df, new_row], ignore_index=True)
df.to_csv(file, index=False)
except Timeout:
print("Timeout!!!")
# Plot results
# 1. Set a “pretty” style
@ -330,7 +426,12 @@ def main(config: Config):
fig.suptitle("Comparison of Donut vs. Dent Trends", fontsize=16, fontweight="bold")
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.savefig("constraint_stats.png")
save_folder = (
Path("results") / f"mlp_l{config.num_hidden_layers}_h{config.hidden_size}"
)
save_folder.mkdir(parents=True, exist_ok=True)
plt.savefig(os.path.join(save_folder, "constraint_stats.png"))
plt.close()
if __name__ == "__main__":

13
uv.lock

@ -220,10 +220,12 @@ source = { virtual = "." }
dependencies = [
{ name = "distrax" },
{ name = "einops" },
{ name = "filelock" },
{ name = "flax" },
{ name = "jax", extra = ["cuda12"] },
{ name = "numpy" },
{ name = "orbax" },
{ name = "pandas" },
{ name = "seaborn" },
{ name = "tqdm" },
{ name = "tyro" },
@ -234,10 +236,12 @@ dependencies = [
requires-dist = [
{ name = "distrax", specifier = ">=0.1.5" },
{ name = "einops", specifier = ">=0.8.1" },
{ name = "filelock", specifier = ">=3.18.0" },
{ name = "flax", specifier = ">=0.10.6" },
{ name = "jax", extras = ["cuda12"], specifier = ">=0.6.0" },
{ name = "numpy", specifier = ">=2.2.5" },
{ name = "orbax", specifier = ">=0.1.9" },
{ name = "pandas", specifier = ">=2.2.3" },
{ name = "seaborn", specifier = ">=0.13.2" },
{ name = "tqdm", specifier = ">=4.67.1" },
{ name = "tyro", specifier = ">=0.9.19" },
@ -341,6 +345,15 @@ epy = [
{ name = "typing-extensions" },
]
[[package]]
name = "filelock"
version = "3.18.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/0a/10/c23352565a6544bdc5353e0b15fc1c563352101f30e24bf500207a54df9a/filelock-3.18.0.tar.gz", hash = "sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2", size = 18075 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl", hash = "sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de", size = 16215 },
]
[[package]]
name = "flax"
version = "0.10.6"

Loading…
Cancel
Save