4 changed files with 223 additions and 15 deletions
@ -0,0 +1,92 @@ |
|||
#!/usr/bin/env python3 |
|||
""" |
|||
pd_plot.py: Load the aggregated experiments.csv and plot constraint satisfaction trends |
|||
for all experiments (varying number of hidden layers). |
|||
""" |
|||
|
|||
import argparse |
|||
from pathlib import Path |
|||
import pandas as pd |
|||
import seaborn as sns |
|||
import matplotlib.pyplot as plt |
|||
|
|||
|
|||
def parse_args(): |
|||
parser = argparse.ArgumentParser( |
|||
description="Plot constraint satisfaction across experiments" |
|||
) |
|||
parser.add_argument( |
|||
"--results_dir", |
|||
type=Path, |
|||
default=Path("results"), |
|||
help="Directory containing experiments.csv", |
|||
) |
|||
parser.add_argument( |
|||
"--output_dir", |
|||
type=Path, |
|||
default=Path("results"), |
|||
help="Directory to save the combined plot", |
|||
) |
|||
return parser.parse_args() |
|||
|
|||
|
|||
def main(): |
|||
args = parse_args() |
|||
csv_path = args.results_dir / "experiments.csv" |
|||
if not csv_path.exists(): |
|||
raise FileNotFoundError(f"Could not find experiments.csv at {csv_path}") |
|||
|
|||
# Load data |
|||
df = pd.read_csv(csv_path) |
|||
|
|||
# Set style |
|||
sns.set_theme(style="whitegrid", palette="pastel") |
|||
|
|||
# Create figure with two subplots |
|||
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True) |
|||
|
|||
# Plot p_donut |
|||
sns.lineplot( |
|||
data=df, |
|||
x="step", |
|||
y="p_donut", |
|||
hue="num_hidden_layers", |
|||
palette="tab10", |
|||
ax=axes[0], |
|||
) |
|||
axes[0].set_title("Donut Constraint Satisfaction") |
|||
axes[0].set_xlabel("Step") |
|||
axes[0].set_ylabel("% Constraint satisfied") |
|||
axes[0].legend(title="Hidden Layers") |
|||
|
|||
# Plot p_dent |
|||
sns.lineplot( |
|||
data=df, |
|||
x="step", |
|||
y="p_dent", |
|||
hue="num_hidden_layers", |
|||
palette="tab10", |
|||
ax=axes[1], |
|||
) |
|||
axes[1].set_title("Dent Constraint Satisfaction") |
|||
axes[1].set_xlabel("Step") |
|||
axes[1].set_ylabel("% Constraint satisfied") |
|||
axes[1].legend(title="Hidden Layers") |
|||
|
|||
# Overall title and layout |
|||
fig.suptitle( |
|||
"Constraint Satisfaction vs. Steps Across Hidden Layers", |
|||
fontsize=16, |
|||
fontweight="bold", |
|||
) |
|||
fig.tight_layout(rect=[0, 0.03, 1, 0.95]) |
|||
|
|||
# Save figure |
|||
args.output_dir.mkdir(parents=True, exist_ok=True) |
|||
out_path = args.output_dir / "all_experiments_constraint_stats.png" |
|||
fig.savefig(out_path, dpi=300, bbox_inches="tight") |
|||
print(f"Saved combined plot to {out_path}") |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
main() |
|||
Loading…
Reference in new issue