-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
85 lines (69 loc) · 2.61 KB
/
main.py
File metadata and controls
85 lines (69 loc) · 2.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#!/usr/bin/env python3
"""
Entry point: build panel (ELSA or --demo), train models, save metrics and ROC plots.
Run from the `elsa_ml_project` directory:
python main.py --demo
python main.py
For real ELSA runs, set ELSA_STATA_DIR to your .../stata13_se folder (see README).
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import config
from build_panel import build_lagged_panel_from_files, make_demo_panel
from train_eval import evaluate_models, metrics_to_csv, prepare_xy
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser with the ``--demo`` and ``--out`` options."""
    cli = argparse.ArgumentParser(description="ELSA mental health ML pipeline")
    cli.add_argument(
        "--demo",
        action="store_true",
        help="Use synthetic data only (no .dta files). Safe for Colab CI.",
    )
    cli.add_argument(
        "--out",
        type=str,
        default=None,
        help="Output directory for metrics and figures (default: config.OUTPUT_DIR)",
    )
    return cli


def main(argv: list[str] | None = None) -> int:
    """Run the pipeline end to end and return a process exit code.

    Parses ``argv`` (``None`` lets argparse read ``sys.argv``), obtains a panel
    from either ``make_demo_panel`` or ``build_lagged_panel_from_files``,
    passes it through ``prepare_xy`` and ``evaluate_models``, writes
    ``metrics.csv`` and ``metrics.json`` into the output directory, and prints
    a per-model summary to stdout.

    Args:
        argv: Command-line arguments, excluding the program name.

    Returns:
        ``1`` if building the panel from files raises ``FileNotFoundError``
        (the error and a hint are printed to stderr), otherwise ``0``.
    """
    opts = _build_arg_parser().parse_args(argv)

    # An explicit --out wins; otherwise fall back to the configured default.
    if opts.out:
        out_dir = Path(opts.out)
    else:
        out_dir = config.OUTPUT_DIR
    out_dir.mkdir(parents=True, exist_ok=True)

    if not opts.demo:
        try:
            panel, target, desc = build_lagged_panel_from_files()
        except FileNotFoundError as missing:
            # Only the file-based path is guarded; the demo path is not.
            print(missing, file=sys.stderr)
            print(
                "\nTip: use --demo to verify the pipeline without data, or set "
                "ELSA_STATA_DIR to your UKDA stata13_se folder.",
                file=sys.stderr,
            )
            return 1
    else:
        panel, target, desc = make_demo_panel()

    print(desc)
    print("Panel shape:", panel.shape)

    X, y, feat_cols = prepare_xy(panel, target=target)
    print(f"Modelling rows: {len(y)}, features: {len(feat_cols)}")

    results = evaluate_models(X, y, out_dir=out_dir)

    metrics_path = out_dir / "metrics.csv"
    metrics_to_csv(results, metrics_path)

    with open(out_dir / "metrics.json", "w", encoding="utf-8") as json_file:
        # The classification report is left out of the JSON copy of the metrics.
        trimmed = {
            model: {
                key: value
                for key, value in scores.items()
                if key != "classification_report"
            }
            for model, scores in results.items()
        }
        json.dump(trimmed, json_file, indent=2)

    print("\n=== Metrics (test set) ===")
    for model, scores in results.items():
        headline = (
            f"{model}: ROC-AUC={scores['roc_auc']:.4f} F1={scores['f1']:.4f} "
            f"precision={scores['precision']:.4f} recall={scores['recall']:.4f}"
        )
        print(headline)
        print(f" Confusion matrix: {scores['confusion_matrix']}")
        print(f" ROC plot: {scores.get('roc_path', '')}")
    print(f"\nSaved: {metrics_path}")
    return 0
if __name__ == "__main__":
    # Exit the interpreter with whatever status code main() returns.
    sys.exit(main())