-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
70 lines (62 loc) · 2.46 KB
/
utils.py
File metadata and controls
70 lines (62 loc) · 2.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# utils.py
import os
import random
import numpy as np
import torch
def set_global_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value used for every random number generator below.

    Side effects:
        * Sets the ``PYTHONHASHSEED`` environment variable to ``str(seed)``.
          NOTE: the running interpreter reads this variable only at startup,
          so it influences child processes launched afterwards, not the
          current process's ``hash()`` randomisation.
        * Seeds ``random``, ``numpy.random``, the PyTorch CPU generator and
          the CUDA generators (current device and all devices).
        * Sets ``torch.backends.cudnn.deterministic = True`` and
          ``torch.backends.cudnn.benchmark = False`` so cuDNN does not pick
          kernels by autotuning, trading speed for reproducibility.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Every generator gets the same seed; order between them does not matter.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
import re
def parse_model_ckpt_name(ckpt_name, replacement_type):
    """Recover training hyperparameters encoded in a checkpoint filename.

    The name must start with the scheme for ``replacement_type``:

    * ``"rotor"``::

          model_chunk=<int>_layers=<int>_breadth=<int>_lr=<float>
          _weightdecay=<float>_batchsize=<int>_nonlinear=<letters>
          _normalize=<bool>_residual=<bool>_useperm=<bool>_epochs=<int>
          _cosanneal=<bool>_projparallel=<bool>_singlerotor=<bool>

    * ``"lowrank_linear"``::

          lowrank_linear_rank=<int>_lr=<float>_batchsize=<int>
          _epochs=<int>_cosanneal=<bool>

    ``<bool>`` is the literal ``True`` or ``False``. ``<float>`` may be a plain
    decimal (``0.001``) or scientific notation (``1e-05``).

    Args:
        ckpt_name: Checkpoint file name. Matching is anchored at the start of
            the string, so pass a bare file name, not a path with directories.
            Everything from the first ``.pth`` onward is discarded before
            matching.
        replacement_type: ``"rotor"`` or ``"lowrank_linear"``.

    Returns:
        dict mapping hyperparameter names to parsed ``int``, ``float``,
        ``bool`` or ``str`` values. The keys depend on ``replacement_type``.

    Raises:
        ValueError: If ``ckpt_name`` does not match the scheme, or if
            ``replacement_type`` is not one of the supported values.
    """
    # Float fields: plain decimals or scientific notation. The exponent is
    # needed because str(1e-05) == "1e-05", so f"lr={lr}" names use that form.
    # The exponent group is optional and never contains "_", so every name
    # that matched the old decimal-only patterns still matches identically.
    exponent = r"(?:[eE][-+]?[0-9]+)?"
    lr_num = r"[0-9.]+" + exponent
    wd_num = r"[0-9]+(?:\.[0-9]+)?" + exponent

    # Strip the ".pth" extension (and anything after its first occurrence).
    stem = ckpt_name.split(".pth")[0]

    if replacement_type == "rotor":
        pattern = (
            r"model_chunk=(\d+)_layers=(\d+)_breadth=(\d+)"
            rf"_lr=({lr_num})_weightdecay=({wd_num})_batchsize=(\d+)"
            r"_nonlinear=([a-zA-Z]+)_normalize=(True|False)_residual=(True|False)"
            r"_useperm=(True|False)_epochs=(\d+)_cosanneal=(True|False)"
            r"_projparallel=(True|False)_singlerotor=(True|False)"
        )
        match = re.match(pattern, stem)
        if not match:
            raise ValueError("Checkpoint filename does not match expected format")
        (chunk_size, layers, breadth,
         lr, weight_decay, batch_size,
         nonlinear, normalize, residual,
         use_perm, epochs, cos_anneal, proj_parallel, single_rotor) = match.groups()
        return {
            "chunk_size": int(chunk_size),
            "hidden_layers": int(layers),
            "breadth_hidden": int(breadth),
            "lr": float(lr),
            "weight_decay": float(weight_decay),
            "batch_size": int(batch_size),
            "nonlinear": nonlinear,
            "normalize": normalize == "True",
            "residual": residual == "True",
            "use_perm": use_perm == "True",
            "epochs": int(epochs),
            "cos_annealing": cos_anneal == "True",
            "proj_parallel": proj_parallel == "True",
            "single_rotor": single_rotor == "True",
        }

    if replacement_type == "lowrank_linear":
        pattern = (
            rf"lowrank_linear_rank=(\d+)_lr=({lr_num})_batchsize=(\d+)_"
            r"epochs=(\d+)_cosanneal=(True|False)"
        )
        match = re.match(pattern, stem)
        if not match:
            raise ValueError("Checkpoint filename does not match expected format")
        rank, lr, batch_size, epochs, cos_annealing = match.groups()
        return {
            "rank": int(rank),
            "lr": float(lr),
            "batch_size": int(batch_size),
            "epochs": int(epochs),
            "cos_annealing": cos_annealing == "True",
        }

    # Previously this fell through and returned None, which hid typos in
    # replacement_type until a caller tried to index the result.
    raise ValueError(f"Unknown replacement_type: {replacement_type!r}")