-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
34 lines (29 loc) · 1.12 KB
/
main.py
File metadata and controls
34 lines (29 loc) · 1.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from inference import run_inference
from tools import load_test_cases, get_results, create_prompt, save_to_file
from evaluation import evaluate, summarize_score
import os, sys, json
from tqdm import tqdm
import argparse
# Number of model completions sampled per test case.
NUM_SAMPLES = 3


def _parse_args() -> argparse.Namespace:
    """Parse the CLI: a subcommand plus a provider-prefixed model name."""
    parser = argparse.ArgumentParser(
        description="Generate model outputs for the test cases, or evaluate saved results."
    )
    parser.add_argument("command", choices=("generate", "evaluate"),
                        help="what to do with the given model's results")
    parser.add_argument("model",
                        help="model name of the format openai/gpt-35-turbo-1106")
    args = parser.parse_args()
    # Validate through the parser rather than `assert`, which is stripped
    # under `python -O` and crashes with a traceback instead of a usage error.
    if "/" not in args.model:
        parser.error("model name should be of the format openai/gpt-35-turbo-1106")
    return args


def _generate(model: str) -> None:
    """Sample the model NUM_SAMPLES times per test case and save raw outputs."""
    results = []
    for test in tqdm(load_test_cases()):
        prompt = create_prompt(test)
        # Multiple samples per task allow pass@k-style scoring downstream.
        outputs = [run_inference(model, prompt).model_dump()
                   for _ in range(NUM_SAMPLES)]
        results.append({"task_id": test["task_id"], "outputs": outputs})
    save_to_file(model, results)


def _evaluate(model: str) -> None:
    """Score previously saved results against the test cases and print a summary."""
    results = get_results(model)
    scores = [evaluate(test, results) for test in tqdm(load_test_cases())]
    summarize_score(scores)


def main() -> None:
    """Entry point: dispatch the requested subcommand."""
    args = _parse_args()
    if args.command == "generate":
        _generate(args.model)
    else:
        _evaluate(args.model)


if __name__ == "__main__":  # guard so importing this module has no side effects
    main()