You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

92 lines
2.4 KiB

#!/usr/bin/env python3
"""
pd_plot.py: Load the aggregated experiments.csv and plot constraint satisfaction trends
for all experiments (varying number of hidden layers).
"""
import argparse
from pathlib import Path
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
def parse_args(argv=None):
    """Parse the command-line options for this plotting script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Defaults to ``None``, which keeps the normal
            command-line behaviour. Passing a list lets other code or tests
            drive the CLI without touching ``sys.argv``.

    Returns:
        argparse.Namespace with:
            results_dir (Path): directory that must contain ``experiments.csv``
                (default ``results``).
            output_dir (Path): directory the combined plot is written to
                (default ``results``).
    """
    parser = argparse.ArgumentParser(
        description="Plot constraint satisfaction across experiments"
    )
    parser.add_argument(
        "--results_dir",
        type=Path,
        default=Path("results"),
        help="Directory containing experiments.csv",
    )
    parser.add_argument(
        "--output_dir",
        type=Path,
        default=Path("results"),
        help="Directory to save the combined plot",
    )
    # argparse falls back to sys.argv[1:] when argv is None.
    return parser.parse_args(argv)
def _plot_constraint(ax, df, column, title):
    """Draw one line per ``num_hidden_layers`` value of *column* against ``step`` on *ax*.

    Args:
        ax: Matplotlib axes to draw into.
        df: Experiment table containing ``step``, ``num_hidden_layers`` and *column*.
        column: Name of the constraint-satisfaction column plotted on the y axis.
        title: Title for this subplot.
    """
    sns.lineplot(
        data=df,
        x="step",
        y=column,
        hue="num_hidden_layers",
        palette="tab10",
        ax=ax,
    )
    ax.set_title(title)
    ax.set_xlabel("Step")
    # NOTE(review): this label assumes the column already holds percentages;
    # if the values are fractions in [0, 1] the label is misleading — confirm
    # against whatever writes experiments.csv.
    ax.set_ylabel("% Constraint satisfied")
    ax.legend(title="Hidden Layers")


def main():
    """Plot donut/dent constraint satisfaction for all experiments and save it.

    Reads ``<results_dir>/experiments.csv``, draws two side-by-side line plots
    (``p_donut`` and ``p_dent`` versus ``step``, one line per
    ``num_hidden_layers``), and writes
    ``<output_dir>/all_experiments_constraint_stats.png`` at 300 dpi.

    Raises:
        FileNotFoundError: if ``experiments.csv`` does not exist.
        ValueError: if the CSV lacks any of the columns the plots need.
    """
    args = parse_args()
    csv_path = args.results_dir / "experiments.csv"
    if not csv_path.exists():
        raise FileNotFoundError(f"Could not find experiments.csv at {csv_path}")

    df = pd.read_csv(csv_path)

    # Fail early with a readable message instead of a seaborn lookup error
    # raised from deep inside lineplot when the CSV schema is unexpected.
    required = {"step", "num_hidden_layers", "p_donut", "p_dent"}
    missing = required.difference(df.columns)
    if missing:
        raise ValueError(
            f"{csv_path} is missing required column(s): {', '.join(sorted(missing))}"
        )

    sns.set_theme(style="whitegrid", palette="pastel")

    fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
    try:
        _plot_constraint(axes[0], df, "p_donut", "Donut Constraint Satisfaction")
        _plot_constraint(axes[1], df, "p_dent", "Dent Constraint Satisfaction")

        fig.suptitle(
            "Constraint Satisfaction vs. Steps Across Hidden Layers",
            fontsize=16,
            fontweight="bold",
        )
        # Leave room at the top for the suptitle.
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])

        args.output_dir.mkdir(parents=True, exist_ok=True)
        out_path = args.output_dir / "all_experiments_constraint_stats.png"
        fig.savefig(out_path, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even if plotting/saving fails, so repeated calls
        # do not accumulate open figures in pyplot's global registry.
        plt.close(fig)
    print(f"Saved combined plot to {out_path}")


if __name__ == "__main__":
    main()