-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathplot_trials.py
More file actions
203 lines (183 loc) · 7.06 KB
/
plot_trials.py
File metadata and controls
203 lines (183 loc) · 7.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import argparse
import csv
import json
import logging
import os
from pathlib import Path
import matplotlib.pyplot as plt
from alno.utils.loading import setup_logging
from alno.utils.plotting import (
plot_averaged_cumulative_regret,
plot_cumulative_max,
plot_cumulative_regret,
plot_summary_across_trials,
)
from run_trials import EXPERIMENT_SCRIPTS
def load_results(results_dirs, algorithms, n_trials):
    """
    Load outcome and regret data from all specified directories and trials.

    Each results directory is expected to contain ``trial_<k>`` subdirectories
    (k = 1..n_trials), each holding one ``pool_<algo>.csv`` per algorithm with
    ``outcome`` and ``regret`` columns.

    Args:
        results_dirs: Iterable of paths to results directories. When several
            directories provide data for the same trial index, the directory
            seen last wins (earlier data for that trial is overwritten).
        algorithms: Algorithm names; matched case-insensitively to file names.
        n_trials: Number of trials to look for in each directory.

    Returns:
        algo_outcomes: dict of algo -> list of list of outcomes per trial
        algo_regret: dict of algo -> list of list of regrets per trial
    """
    # Resolve the logger locally: the module-level `logger` is only created
    # under `if __name__ == "__main__":`, so relying on it would raise
    # NameError when this module is imported. getLogger returns the same
    # named logger object, so logging behavior is unchanged when run as a
    # script.
    logger = logging.getLogger(__name__)
    algo_outcomes = {algo.lower(): [] for algo in algorithms}
    algo_regret = {algo.lower(): [] for algo in algorithms}
    for results_dir in results_dirs:
        for trial in range(1, n_trials + 1):
            trial_dir = Path(results_dir) / f"trial_{trial}"
            for algo in algorithms:
                algo = algo.lower()
                csv_path = trial_dir / f"pool_{algo}.csv"
                if not csv_path.exists():
                    logger.warning(f"CSV does not exist: {csv_path}")
                    continue
                outcomes = []
                regret = []
                with open(csv_path, "r") as f:
                    reader = csv.DictReader(f)
                    for row in reader:
                        # Parse BOTH columns before appending either, so a
                        # row whose outcome parses but whose regret does not
                        # cannot leave the two lists with mismatched lengths
                        # (the previous code appended the outcome first and
                        # then failed on the regret).
                        try:
                            outcome_val = float(row["outcome"])
                            regret_val = float(row["regret"])
                        except (KeyError, TypeError, ValueError):
                            logger.warning(
                                f"Skipping invalid row in {csv_path}: {row}"
                            )
                            continue
                        outcomes.append(outcome_val)
                        regret.append(regret_val)
                if not outcomes:
                    logger.warning(f"No valid outcomes in {csv_path}")
                    continue
                # First sighting of this trial index: append; a later
                # directory providing the same trial overwrites in place.
                if len(algo_outcomes[algo]) < trial:
                    algo_outcomes[algo].append(outcomes)
                    algo_regret[algo].append(regret)
                else:
                    algo_outcomes[algo][trial - 1] = outcomes
                    algo_regret[algo][trial - 1] = regret
    return algo_outcomes, algo_regret
def _save_current_figure(out_dir, filename, figdpi):
    """Save the active matplotlib figure to out_dir/filename and close it."""
    output_path = os.path.join(out_dir, filename)
    plt.savefig(output_path, dpi=figdpi)
    logger.info("Saved plot: %s", output_path)
    plt.close()


def main(
    results_dirs,
    algorithms,
    n_trials,
    max_iterations,
    figdpi,
    n_std,
    truncate_to_shortest,
    output_dir=None,
):
    """
    Generate summary plots (outcomes, cumulative max, cumulative regret,
    averaged cumulative regret) from run_trials results and save them, plus
    two JSON records, into the output directory.

    Args:
        results_dirs: List of results directories to load trials from.
        algorithms: Algorithm names to include in the plots.
        n_trials: Number of trials to load per directory.
        max_iterations: Upper x-axis limit for every plot.
        figdpi: DPI used when saving figures.
        n_std: Number of standard deviations for confidence intervals.
        truncate_to_shortest: Truncate regret data to the shortest trial
            instead of padding.
        output_dir: Where to write plots/JSON. Defaults to the script-level
            ``args.output_dir`` when run as a script (backward-compatible with
            the previous implicit global read), then to ``results_dirs[0]``.
    """
    # Load data
    logger.info("Loading results from directories: %s", results_dirs)
    algo_outcomes, algo_regret = load_results(results_dirs, algorithms, n_trials)

    # Resolve the output directory. The explicit parameter wins; otherwise
    # fall back to the CLI `args` global (previous behavior), then to the
    # first results directory. getattr(None, ...) safely yields None when
    # the module is imported and `args` does not exist.
    out_dir = output_dir
    if out_dir is None:
        out_dir = getattr(globals().get("args"), "output_dir", None)
    if out_dir is None:
        out_dir = results_dirs[0]
    os.makedirs(out_dir, exist_ok=True)

    # Save the directories used into a JSON file
    json_path = os.path.join(out_dir, "plot_trials_data_dirs.json")
    with open(json_path, "w") as json_file:
        json.dump({"results_dirs": results_dirs, "n_std": n_std}, json_file)
    logger.info(f"Saved list of data directories to {json_path}")

    # Plotting
    colors = plt.get_cmap("tab10").colors
    figsize = (5, 3)

    plt.figure(figsize=figsize)
    plot_summary_across_trials(algo_outcomes, colors, n_std=n_std)
    plt.legend()
    plt.title("Average outcomes across trials")
    plt.xlim(0, max_iterations)
    _save_current_figure(out_dir, "summary_across_trials.png", figdpi)

    plt.figure(figsize=figsize)
    plot_cumulative_max(algo_outcomes, colors, n_std=n_std)
    plt.legend()
    plt.title("Cumulative max (best-so-far) across trials")
    plt.xlim(0, max_iterations)
    _save_current_figure(out_dir, "cumulative_max_across_trials.png", figdpi)

    plt.figure(figsize=figsize)
    plot_cumulative_regret(
        algo_regret, colors, n_std=n_std, truncate_to_shortest=truncate_to_shortest
    )
    plt.legend()
    plt.title("Cumulative regret across trials")
    plt.xlim(0, max_iterations)
    _save_current_figure(out_dir, "cumulative_regret_across_trials.png", figdpi)

    plt.figure(figsize=figsize)
    regrets_summary = plot_averaged_cumulative_regret(
        algo_regret, colors, n_std=n_std, truncate_to_shortest=truncate_to_shortest
    )
    plt.legend()
    plt.title("Averaged cumulative regret across trials")
    plt.xlim(0, max_iterations)
    _save_current_figure(out_dir, "averaged_cumulative_regret.png", figdpi)

    # Save regrets summary to JSON
    summary_json_path = os.path.join(out_dir, "average_regrets.json")
    with open(summary_json_path, "w") as json_file:
        json.dump(regrets_summary, json_file, indent=4)
    logger.info(f"Saved regrets summary to {summary_json_path}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate summary plots from run_trials results"
    )
    parser.add_argument(
        "results_dirs",
        nargs="+",
        type=str,
        help="Paths to one or more results directories",
    )
    parser.add_argument(
        "--algorithms", nargs="+", type=str, default=None, help="List of algorithms"
    )
    parser.add_argument("--n-trials", type=int, default=10, help="Number of trials")
    parser.add_argument(
        "--max-iterations", type=int, default=None, help="Maximum number of iterations"
    )
    parser.add_argument("--figdpi", type=int, default=300, help="DPI for the plot")
    parser.add_argument(
        "--n-std",
        type=int,
        default=1,
        help="Number of standard deviations for confidence intervals",
    )
    parser.add_argument(
        "--truncate-to-shortest",
        action="store_true",
        help="Truncate data to shortest trial, instead of padding",
    )
    parser.add_argument(
        "-o", "--output-dir", type=str, default=None, help="Output directory"
    )
    args = parser.parse_args()

    setup_logging()
    logger = logging.getLogger(__name__)

    # If algorithms not provided, infer them from the experiment registry
    # (the same set run_trials launches with).
    if args.algorithms is None:
        algorithms = list(EXPERIMENT_SCRIPTS.keys())
        logger.info(f"Inferred algorithms: {algorithms}")
    else:
        algorithms = args.algorithms
        logger.info(f"Using specified algorithms: {algorithms}")

    try:
        main(
            args.results_dirs,
            algorithms,
            args.n_trials,
            args.max_iterations,
            args.figdpi,
            args.n_std,
            args.truncate_to_shortest,
        )
    except Exception:
        logger.exception("Unexpected error in main function")
        # Previously the exception was only logged and the script exited with
        # status 0, hiding failures from shells and CI. Exit non-zero instead.
        raise SystemExit(1)