You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
92 lines
2.4 KiB
92 lines
2.4 KiB
#!/usr/bin/env python3
"""
pd_plot.py: Load the aggregated experiments.csv and plot constraint-satisfaction
trends for all experiments (varying the number of hidden layers).
"""
|
|
|
|
import argparse
|
|
from pathlib import Path
|
|
import pandas as pd
|
|
import seaborn as sns
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
def parse_args(argv: "list[str] | None" = None):
    """Parse the command-line options for the plotting script.

    Args:
        argv: Argument strings to parse, excluding the program name. When
            ``None`` (the default) argparse falls back to ``sys.argv[1:]``,
            so existing callers behave exactly as before; passing an explicit
            list lets tests and other scripts drive the parser directly.

    Returns:
        An ``argparse.Namespace`` with ``results_dir`` and ``output_dir``,
        both converted to ``pathlib.Path`` (default ``Path("results")``).
    """
    parser = argparse.ArgumentParser(
        description="Plot constraint satisfaction across experiments"
    )
    parser.add_argument(
        "--results_dir",
        type=Path,
        default=Path("results"),
        help="Directory containing experiments.csv",
    )
    parser.add_argument(
        "--output_dir",
        type=Path,
        default=Path("results"),
        help="Directory to save the combined plot",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
# (CSV column, subplot title) for each panel, drawn left to right.
_CONSTRAINT_PANELS = (
    ("p_donut", "Donut Constraint Satisfaction"),
    ("p_dent", "Dent Constraint Satisfaction"),
)

# Every column of experiments.csv that the plots read.
_REQUIRED_COLUMNS = ("step", "num_hidden_layers") + tuple(
    column for column, _ in _CONSTRAINT_PANELS
)


def _plot_constraint_panel(ax, df, column, title):
    """Draw *column* against ``step`` on *ax*, coloured by ``num_hidden_layers``."""
    sns.lineplot(
        data=df,
        x="step",
        y=column,
        hue="num_hidden_layers",
        palette="tab10",
        ax=ax,
    )
    ax.set_title(title)
    ax.set_xlabel("Step")
    ax.set_ylabel("% Constraint satisfied")
    # Rebuild the legend so it carries a human-readable title instead of
    # the raw column name.
    ax.legend(title="Hidden Layers")


def main():
    """Plot donut/dent constraint satisfaction for all experiments and save a PNG.

    Reads ``<results_dir>/experiments.csv`` and draws one panel per
    constraint column side by side (shared y-axis). The figure is written to
    ``<output_dir>/all_experiments_constraint_stats.png``.

    Raises:
        FileNotFoundError: if experiments.csv does not exist.
        ValueError: if the CSV lacks any column the plots need.
    """
    args = parse_args()

    csv_path = args.results_dir / "experiments.csv"
    if not csv_path.exists():
        raise FileNotFoundError(f"Could not find experiments.csv at {csv_path}")

    df = pd.read_csv(csv_path)

    # Fail early with a clear message rather than an opaque seaborn error.
    missing = [column for column in _REQUIRED_COLUMNS if column not in df.columns]
    if missing:
        raise ValueError(
            f"{csv_path} is missing required column(s): {', '.join(missing)}"
        )

    sns.set_theme(style="whitegrid", palette="pastel")

    fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
    try:
        for ax, (column, title) in zip(axes, _CONSTRAINT_PANELS):
            _plot_constraint_panel(ax, df, column, title)

        fig.suptitle(
            "Constraint Satisfaction vs. Steps Across Hidden Layers",
            fontsize=16,
            fontweight="bold",
        )
        # Leave headroom at the top of the canvas for the suptitle.
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])

        args.output_dir.mkdir(parents=True, exist_ok=True)
        out_path = args.output_dir / "all_experiments_constraint_stats.png"
        fig.savefig(out_path, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails, so repeated
        # calls (e.g. from a notebook or driver script) do not leak figures.
        plt.close(fig)

    print(f"Saved combined plot to {out_path}")
|
|
|
|
|
|
# Run the plotting pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|