Browse Source

feat: experiments running

master
CALVO GONZALEZ Ramon 9 months ago
parent
commit
30c0bd30c0
  1. 92
      pd_plots.py
  2. 2
      pyproject.toml
  3. 131
      train.py
  4. 13
      uv.lock

92
pd_plots.py

@ -0,0 +1,92 @@
#!/usr/bin/env python3
"""
pd_plot.py: Load the aggregated experiments.csv and plot constraint satisfaction trends
for all experiments (varying number of hidden layers).
"""
import argparse
from pathlib import Path
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
def parse_args():
parser = argparse.ArgumentParser(
description="Plot constraint satisfaction across experiments"
)
parser.add_argument(
"--results_dir",
type=Path,
default=Path("results"),
help="Directory containing experiments.csv",
)
parser.add_argument(
"--output_dir",
type=Path,
default=Path("results"),
help="Directory to save the combined plot",
)
return parser.parse_args()
def main():
args = parse_args()
csv_path = args.results_dir / "experiments.csv"
if not csv_path.exists():
raise FileNotFoundError(f"Could not find experiments.csv at {csv_path}")
# Load data
df = pd.read_csv(csv_path)
# Set style
sns.set_theme(style="whitegrid", palette="pastel")
# Create figure with two subplots
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
# Plot p_donut
sns.lineplot(
data=df,
x="step",
y="p_donut",
hue="num_hidden_layers",
palette="tab10",
ax=axes[0],
)
axes[0].set_title("Donut Constraint Satisfaction")
axes[0].set_xlabel("Step")
axes[0].set_ylabel("% Constraint satisfied")
axes[0].legend(title="Hidden Layers")
# Plot p_dent
sns.lineplot(
data=df,
x="step",
y="p_dent",
hue="num_hidden_layers",
palette="tab10",
ax=axes[1],
)
axes[1].set_title("Dent Constraint Satisfaction")
axes[1].set_xlabel("Step")
axes[1].set_ylabel("% Constraint satisfied")
axes[1].legend(title="Hidden Layers")
# Overall title and layout
fig.suptitle(
"Constraint Satisfaction vs. Steps Across Hidden Layers",
fontsize=16,
fontweight="bold",
)
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
# Save figure
args.output_dir.mkdir(parents=True, exist_ok=True)
out_path = args.output_dir / "all_experiments_constraint_stats.png"
fig.savefig(out_path, dpi=300, bbox_inches="tight")
print(f"Saved combined plot to {out_path}")
if __name__ == "__main__":
main()

2
pyproject.toml

@ -7,10 +7,12 @@ requires-python = ">=3.11"
dependencies = [ dependencies = [
"distrax>=0.1.5", "distrax>=0.1.5",
"einops>=0.8.1", "einops>=0.8.1",
"filelock>=3.18.0",
"flax>=0.10.6", "flax>=0.10.6",
"jax[cuda12]>=0.6.0", "jax[cuda12]>=0.6.0",
"numpy>=2.2.5", "numpy>=2.2.5",
"orbax>=0.1.9", "orbax>=0.1.9",
"pandas>=2.2.3",
"seaborn>=0.13.2", "seaborn>=0.13.2",
"tqdm>=4.67.1", "tqdm>=4.67.1",
"tyro>=0.9.19", "tyro>=0.9.19",

131
train.py

@ -1,16 +1,22 @@
from typing import Optional from typing import Optional
import os
from pathlib import Path
import tyro import tyro
from functools import partial from functools import partial
from dataclasses import dataclass from dataclasses import dataclass
from tqdm import tqdm from tqdm import tqdm
import seaborn as sns import seaborn as sns
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pandas as pd
from filelock import FileLock, Timeout
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import flax.nnx as nnx import flax.nnx as nnx
import optax import optax
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.1"
@dataclass @dataclass
class Config: class Config:
@ -37,7 +43,7 @@ class Config:
num_steps: int = 100_000 num_steps: int = 100_000
"""How many steps of gradient descent to perform.""" """How many steps of gradient descent to perform."""
batch_size: int = 512 batch_size: int = 256
"""How many samples per mini-batch.""" """How many samples per mini-batch."""
r1: float = 0.3 r1: float = 0.3
@ -55,6 +61,12 @@ class Config:
evaluate_constraints_every: int = 0 evaluate_constraints_every: int = 0
"""Evaluate the constraints after given steps.""" """Evaluate the constraints after given steps."""
save_scatterplot: bool = False
"""For every step the constraints are evaluated, save a scatterplot."""
show_p_data: bool = False
"""If set, the script will generate a scatterplot sampled from the p_data process."""
seed: int = 42 seed: int = 42
"""The seed used for randomness.""" """The seed used for randomness."""
@ -109,16 +121,32 @@ def array_bool_statistics(x: jax.Array) -> float:
def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int] = 0): def evaluate_constraints(key: jax.random.PRNGKey, model, save_plot: Optional[int] = 0):
p_data_key, p_t_key = jax.random.split(key) p_data_key, p_t_key = jax.random.split(key)
samples = ode_trajectory(p_t_key, model=model, num_samples=1024, config=config) samples = ode_trajectory(
p_t_key,
model=model,
num_samples=32_768,
sample_steps=config.sample_steps,
space_dimensions=config.space_dimensions,
)
stats_donut = array_bool_statistics( stats_donut = array_bool_statistics(
check_donut_constraint(samples, r1=config.r1, r2=config.r2) jnp.logical_not(check_donut_constraint(samples, r1=config.r1, r2=config.r2))
)
stats_dent = array_bool_statistics(
jnp.logical_not(check_dent_constraint(samples, d=config.d))
) )
stats_dent = array_bool_statistics(check_dent_constraint(samples, d=config.d))
if save_plot is not None: if save_plot is not None:
sns.scatterplot(x=samples[:, 0], y=samples[:, 1]) plt.figure()
plt.savefig(f"scatter_{save_plot}.png", dpi=300, bbox_inches="tight") sns.scatterplot(x=samples[:, 0], y=samples[:, 1], size=0.1)
save_folder = (
Path("results") / f"mlp_l{config.num_hidden_layers}_h{config.hidden_size}"
)
save_folder.mkdir(parents=True, exist_ok=True)
plt.savefig(
f"{save_folder}/scatter_{save_plot}.png", dpi=300, bbox_inches="tight"
)
plt.close()
return stats_donut, stats_dent return stats_donut, stats_dent
@ -213,16 +241,28 @@ def ode_step(model: MLP, x_t: jax.Array, t: jax.Array, h: float) -> jax.Array:
return x_t + h * model(x_t, t) return x_t + h * model(x_t, t)
@partial(nnx.jit, static_argnums=(2, 3, 4))
def ode_trajectory( def ode_trajectory(
key: jax.random.PRNGKey, model: MLP, num_samples: int, config: Config key: jax.random.PRNGKey,
model: MLP,
num_samples: int,
sample_steps: int,
space_dimensions: int,
) -> jax.Array: ) -> jax.Array:
t = jnp.zeros((num_samples,)) t = jnp.zeros((num_samples,))
h = 1.0 / config.sample_steps h = 1.0 / sample_steps
x = jax.random.normal(key=key, shape=(num_samples, config.space_dimensions)) x = jax.random.normal(key=key, shape=(num_samples, space_dimensions))
for i in range(config.sample_steps): def body(i, state):
t, x = state
x = ode_step(model, x, t, h) x = ode_step(model, x, t, h)
t = t + h return (t + h, x)
_, x = jax.lax.fori_loop(0, config.sample_steps, body, (t, x))
# for i in range(sample_steps):
# x = ode_step(model, x, t, h)
# t = t + h
return x return x
@ -241,6 +281,15 @@ def sde_trajectory(model: MLP) -> jax.Array:
def main(config: Config): def main(config: Config):
rngs = nnx.Rngs(config.seed) rngs = nnx.Rngs(config.seed)
if config.show_p_data:
points = sample_p_data(rngs.params(), 32_768, config.r1, config.r2, config.d)
plt.figure()
sns.scatterplot(x=points[:, 0], y=points[:, 1], size=0.1)
if not os.path.exists("./results"):
os.mkdir("results")
plt.savefig("results/p_data.png", dpi=300, bbox_inches="tight")
exit()
model = MLP( model = MLP(
in_features=config.space_dimensions, in_features=config.space_dimensions,
out_features=config.space_dimensions, out_features=config.space_dimensions,
@ -284,6 +333,8 @@ def main(config: Config):
cached_train_step = nnx.cached_partial(train_step, model, optim) cached_train_step = nnx.cached_partial(train_step, model, optim)
a_donut = (jnp.pi - 0.5 * config.d) * (config.r2**2 - config.r1**2)
a_dent = 0.5 * config.d * (config.r2**2 - config.r1**2)
stats_donut = [] stats_donut = []
stats_dent = [] stats_dent = []
for i in tqdm(range(config.num_steps)): for i in tqdm(range(config.num_steps)):
@ -301,9 +352,54 @@ def main(config: Config):
config.evaluate_constraints_every != 0 config.evaluate_constraints_every != 0
and i % config.evaluate_constraints_every == 0 and i % config.evaluate_constraints_every == 0
): ):
stat_donut, stat_dent = evaluate_constraints(rngs.params(), model) stat_donut, stat_dent = evaluate_constraints(
stats_donut.append(stat_donut.item()) rngs.params(), model, save_plot=i if config.save_scatterplot else None
stats_dent.append(stat_dent.item()) )
stats_donut.append(stat_donut.item() / a_donut)
stats_dent.append(stat_dent.item() / a_dent)
try:
with FileLock(os.path.join("results", "experiments.lock")):
file = os.path.join("results", "experiments.csv")
if os.path.exists(file):
df = pd.read_csv(file)
else:
df = pd.DataFrame(
columns=[
"num_hidden_layers",
"hidden_size",
"mlp_bias",
"fourier_dim",
"r1",
"r2",
"d",
"sample_steps",
"p_donut",
"p_dent",
"step",
]
)
new_row = pd.DataFrame(
{
"num_hidden_layers": [config.num_hidden_layers],
"hidden_size": [config.hidden_size],
"mlp_bias": [config.mlp_bias],
"fourier_dim": [config.fourier_dim],
"r1": [config.r1],
"r2": [config.r2],
"d": [config.d],
"sample_steps": [config.sample_steps],
"p_donut": [stats_donut[-1]],
"p_dent": [stats_dent[-1]],
"step": [i],
}
)
df = pd.concat([df, new_row], ignore_index=True)
df.to_csv(file, index=False)
except Timeout:
print("Timeout!!!")
# Plot results # Plot results
# 1. Set a “pretty” style # 1. Set a “pretty” style
@ -330,7 +426,12 @@ def main(config: Config):
fig.suptitle("Comparison of Donut vs. Dent Trends", fontsize=16, fontweight="bold") fig.suptitle("Comparison of Donut vs. Dent Trends", fontsize=16, fontweight="bold")
fig.tight_layout(rect=[0, 0.03, 1, 0.95]) fig.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.savefig("constraint_stats.png") save_folder = (
Path("results") / f"mlp_l{config.num_hidden_layers}_h{config.hidden_size}"
)
save_folder.mkdir(parents=True, exist_ok=True)
plt.savefig(os.path.join(save_folder, "constraint_stats.png"))
plt.close()
if __name__ == "__main__": if __name__ == "__main__":

13
uv.lock

@ -220,10 +220,12 @@ source = { virtual = "." }
dependencies = [ dependencies = [
{ name = "distrax" }, { name = "distrax" },
{ name = "einops" }, { name = "einops" },
{ name = "filelock" },
{ name = "flax" }, { name = "flax" },
{ name = "jax", extra = ["cuda12"] }, { name = "jax", extra = ["cuda12"] },
{ name = "numpy" }, { name = "numpy" },
{ name = "orbax" }, { name = "orbax" },
{ name = "pandas" },
{ name = "seaborn" }, { name = "seaborn" },
{ name = "tqdm" }, { name = "tqdm" },
{ name = "tyro" }, { name = "tyro" },
@ -234,10 +236,12 @@ dependencies = [
requires-dist = [ requires-dist = [
{ name = "distrax", specifier = ">=0.1.5" }, { name = "distrax", specifier = ">=0.1.5" },
{ name = "einops", specifier = ">=0.8.1" }, { name = "einops", specifier = ">=0.8.1" },
{ name = "filelock", specifier = ">=3.18.0" },
{ name = "flax", specifier = ">=0.10.6" }, { name = "flax", specifier = ">=0.10.6" },
{ name = "jax", extras = ["cuda12"], specifier = ">=0.6.0" }, { name = "jax", extras = ["cuda12"], specifier = ">=0.6.0" },
{ name = "numpy", specifier = ">=2.2.5" }, { name = "numpy", specifier = ">=2.2.5" },
{ name = "orbax", specifier = ">=0.1.9" }, { name = "orbax", specifier = ">=0.1.9" },
{ name = "pandas", specifier = ">=2.2.3" },
{ name = "seaborn", specifier = ">=0.13.2" }, { name = "seaborn", specifier = ">=0.13.2" },
{ name = "tqdm", specifier = ">=4.67.1" }, { name = "tqdm", specifier = ">=4.67.1" },
{ name = "tyro", specifier = ">=0.9.19" }, { name = "tyro", specifier = ">=0.9.19" },
@ -341,6 +345,15 @@ epy = [
{ name = "typing-extensions" }, { name = "typing-extensions" },
] ]
[[package]]
name = "filelock"
version = "3.18.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/0a/10/c23352565a6544bdc5353e0b15fc1c563352101f30e24bf500207a54df9a/filelock-3.18.0.tar.gz", hash = "sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2", size = 18075 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl", hash = "sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de", size = 16215 },
]
[[package]] [[package]]
name = "flax" name = "flax"
version = "0.10.6" version = "0.10.6"

Loading…
Cancel
Save