-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
77 lines (63 loc) · 3.33 KB
/
main.py
File metadata and controls
77 lines (63 loc) · 3.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import multiprocessing as mp
import argparse
from datetime import datetime
from mm_detect import ModelContaminationChecker
from mm_detect.configs.config import supported_methods
from mm_detect.utils.utils import seed_everything
from mm_detect.utils.logger import setting_logger
def parse_args():
    """Parse the command-line options and set up the run's logger.

    After parsing, a log file name of the form
    ``<dataset>_<n_eval_data_points>.txt`` is built (``/`` in the dataset
    name is replaced by ``_``) and handed, together with ``--output_dir``,
    to ``setting_logger``. The parsed arguments are then logged at WARNING
    level.

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    # Each entry is (flag, keyword arguments for ``add_argument``). The order
    # is the order in which options are registered and shown in ``--help``.
    general_options = (
        ("--caption_key", {
            "type": str, "default": "",
            "help": "The caption key of each data instance."}),
        ("--image_key", {
            "type": str, "default": "",
            "help": "The key to image content of each data instance."}),
        ("--dataset_name", {
            "type": str, "default": "",
            "help": "If this field is set, we set train_set and eval_set to it"}),
        ("--seed", {
            "type": int, "default": 42,
            "help": "Random seed"}),
        ("--eval_data_name", {
            "type": str, "default": "",
            "help": "Eval dataset name"}),
        ("--eval_data_config_name", {
            "type": str, "default": None,
            "help": "Eval dataset config name"}),
        ("--eval_set_key", {
            "type": str, "default": "test",
            "help": "Eval set key"}),
        ("--text_key", {
            "type": str, "default": "",
            "help": "The key to text content of each data instance."}),
        ("--n_eval_data_points", {
            "type": int, "default": 100,
            "help": "The number of (val/test) data points to keep for evaluating contamination"}),
        ("--method", {
            "type": str, "choices": supported_methods.keys(),
            "help": "you must pass a method name within the list supported_methods"}),
        ("--output_dir", {
            "type": str, "default": "output",
            "help": "Output directory for logging if necessary"}),
    )

    # Method-specific options for model contamination detection
    # (shared across methods).
    model_options = (
        ("--model_name", {
            "type": str, "default": None,
            "help": "Model name for service based inference."}),
        ("--max_output_tokens", {
            "type": int, "default": 128,
            "help": "Max number of output tokens"}),
        ("--temperature", {
            "type": float, "default": 0.0,
            "help": "Temperature when sampling each sample"}),
    )

    # Resume functionality.
    resume_options = (
        ("--resume", {
            "action": "store_true",
            "help": "Resume from previous checkpoint if available"}),
    )

    parser = argparse.ArgumentParser()
    for flag, options in (*general_options, *model_options, *resume_options):
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # The log file is named after --dataset_name, falling back to
    # --eval_data_name when the former is left empty.
    if args.dataset_name == "":
        dataset_label = args.eval_data_name
    else:
        dataset_label = args.dataset_name
    dataset_label = dataset_label.replace("/", "_")

    logger = setting_logger(
        f"{dataset_label}_{args.n_eval_data_points}.txt",
        args.output_dir,
    )
    logger.warning(args)
    return args
def check_args(args):
    """Validate the parsed command-line arguments.

    Args:
        args: Parsed arguments; only ``args.method`` is inspected.

    Raises:
        AssertionError: If ``args.method`` is not a key of
            ``supported_methods`` (this includes ``None`` when ``--method``
            was not given).
    """
    if args.method in supported_methods:
        return
    # Raised explicitly rather than through an ``assert`` statement: asserts
    # are removed when Python runs with ``-O``, which would silently skip this
    # validation. The exception type and message are kept unchanged.
    raise AssertionError(
        f"Error, {args.method} not in supported methods: {list(supported_methods.keys())}"
    )
def main():
    """Run the contamination check selected on the command line.

    Parses and validates the arguments, seeds via ``seed_everything`` with
    ``--seed``, then builds a ``ModelContaminationChecker`` from the
    arguments and calls ``run_contamination`` with the chosen method.
    """
    cli_args = parse_args()
    check_args(cli_args)
    seed_everything(cli_args.seed)

    checker = ModelContaminationChecker(cli_args)
    checker.run_contamination(cli_args.method)
# Only run the CLI when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()