-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval.py
More file actions
30 lines (21 loc) · 946 Bytes
/
eval.py
File metadata and controls
30 lines (21 loc) · 946 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from gsm8k import GSM8K_evaluation
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from model_wrapper import TokenWrapper
import argparse


def _str2bool(value: str) -> bool:
    """Parse a command-line boolean flag value (case-insensitive).

    BUG FIX: the original used ``type=bool``; ``bool("False")`` is True,
    so ``--formatted False`` silently enabled formatting. argparse needs an
    explicit string->bool converter.
    """
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def main() -> None:
    """Evaluate a causal LM on GSM8K and write per-question results to JSON."""
    # gather args
    parser = argparse.ArgumentParser(description="Evaluate a model on GSM8K.")
    parser.add_argument('--model_name', type=str, required=True,
                        help="HF hub id or local path of the model to evaluate.")
    parser.add_argument('--formatted', type=_str2bool, default=True,
                        help="Whether the evaluator applies its formatted prompt (true/false).")
    args = parser.parse_args()

    evaluator = GSM8K_evaluation(formatted=args.formatted)

    # load model
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForCausalLM.from_pretrained(args.model_name)
    model_wrapped = TokenWrapper(model, tokenizer)

    # Results file is named after the last path component of the model id,
    # e.g. "org/model" -> "model_results.json".
    output_file = args.model_name.split("/")[-1] + "_results.json"

    # BUG FIX: the original non-raw string turned ``\b`` into a backspace
    # character (0x08), so the model never saw the literal ``\boxed{}``.
    # A raw string preserves the backslash.
    deepseek_prompt = r'Please reason step by step, and put your final answer within \boxed{}.'

    # evaluate model
    acc, corrects = evaluator.eval(model_wrapped, n_evals=1,
                                   output_file=output_file,
                                   format_instructions=deepseek_prompt)


if __name__ == "__main__":
    main()