Skip to content

Commit a5ae8a0

Browse files
committed
Adapted the config generation script to include the optimizer field (ardoco#2).
TODO: make this field optional for backwards compatibility.
1 parent 768666c commit a5ae8a0

1 file changed

Lines changed: 37 additions & 8 deletions

File tree

generate_configs_r2r.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import os
2+
3+
CONFIG_DIR = "./configs/req2req"
14
TEMPLATE = """
25
{
36
"cache_dir": "./cache/<<DATASET>>",
@@ -40,11 +43,17 @@
4043
"args" : { }
4144
},
4245
"target_store" : {
43-
"name" : "custom",
46+
"name" : "cosine_similarity",
4447
"args" : {
4548
"max_results" : "<<RETRIEVAL_COUNT>>"
4649
}
4750
},
51+
"prompt_optimizer": {
52+
"name" : "simple",
53+
"args" : {
54+
"model": "codellama:13b"
55+
}
56+
},
4857
"classifier" : {
4958
"name" : "<<CLASSIFIER_MODE>>",
5059
"args" : {
@@ -64,25 +73,45 @@
6473

6574
# Configurations
# Full corpus kept for reference; the active run deliberately targets a subset.
# datasets = ["GANNT", "ModisDataset", "CCHIT", "WARC", "dronology", "CM1-NASA"]
datasets = ["WARC", "dronology", "CCHIT"]

# Post-processor per dataset. The previous positional lists were sized for the
# full 6-dataset corpus; once `datasets` is narrowed, zip() silently truncates
# and the pairing is only correct by coincidence. An explicit mapping stays
# correct for any subset selected above.
POSTPROCESSOR_BY_DATASET = {
    "GANNT": "req2req",
    "ModisDataset": "identity",
    "CCHIT": "identity",
    "WARC": "req2req",
    "dronology": "identity",
    "CM1-NASA": "identity",
}
postprocessors = [POSTPROCESSOR_BY_DATASET[d] for d in datasets]
retrieval_counts = ["4"] * len(datasets)  # max_results per dataset, as strings for TEMPLATE

classifier_modes = ["simple", "reasoning"]
# Earlier model sets kept for reference; only the second assignment is active.
# gpt_models = ["gpt-4o-mini-2024-07-18", "gpt-4o-2024-08-06"]
gpt_models = ["o4-mini-2025-04-16"]
# ollama_models = ["llama3.1:8b-instruct-fp16", "codellama:13b"]
ollama_models = []

# Generate one <model>/<dataset> directory tree under CONFIG_DIR.
# exist_ok=True makes the script idempotent across re-runs.
os.makedirs(CONFIG_DIR, exist_ok=True)
for model in gpt_models + ollama_models:
    model_dir = os.path.join(CONFIG_DIR, model.replace(":", "_"))
    for dataset in datasets:
        os.makedirs(os.path.join(model_dir, dataset), exist_ok=True)


def _render(dataset, classifier_mode, args, postprocessor, retrieval_count):
    # Fill every placeholder in TEMPLATE for one concrete config file.
    return (TEMPLATE
            .replace("<<DATASET>>", dataset)
            .replace("<<CLASSIFIER_MODE>>", classifier_mode)
            .replace("<<ARGS>>", args)
            .replace("<<POSTPROCESSOR>>", postprocessor)
            .replace("<<RETRIEVAL_COUNT>>", retrieval_count))


gpt_args = [f'"model": "{m}"' for m in gpt_models]
ollama_args = [f'"model": "{m}"' for m in ollama_models]

for dataset, postprocessor, retrieval_count in zip(datasets, postprocessors, retrieval_counts):
    # Baseline config without an LLM classifier ("mock").
    with open(os.path.join(CONFIG_DIR, f"{dataset}_no_llm.json"), "w") as f:
        f.write(_render(dataset, "mock", "", postprocessor, retrieval_count))

    for classifier_mode in classifier_modes:
        for gpt_model, gpt_arg in zip(gpt_models, gpt_args):
            # Sanitize ":" exactly like the directory-creation loop above —
            # previously only the directory tree was sanitized, so a model
            # name containing ":" would have written to a missing directory.
            safe = gpt_model.replace(":", "_")
            path = os.path.join(CONFIG_DIR, safe, dataset,
                                f"{dataset}_{classifier_mode}_gpt_{safe}.json")
            with open(path, "w") as f:
                f.write(_render(dataset, classifier_mode + "_openai", gpt_arg,
                                postprocessor, retrieval_count))

        for ollama_model, ollama_arg in zip(ollama_models, ollama_args):
            safe = ollama_model.replace(":", "_")
            path = os.path.join(CONFIG_DIR, safe, dataset,
                                f"{dataset}_{classifier_mode}_ollama_{safe}.json")
            with open(path, "w") as f:
                f.write(_render(dataset, classifier_mode + "_ollama", ollama_arg,
                                postprocessor, retrieval_count))

0 commit comments

Comments
 (0)