-
Notifications
You must be signed in to change notification settings - Fork 33
Expand file tree
/
Copy pathmain.py
More file actions
129 lines (98 loc) · 3.67 KB
/
main.py
File metadata and controls
129 lines (98 loc) · 3.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
from pathlib import Path
import rich.logging
import torch
import hydra
import warnings
import logging
from tracklab.datastruct import TrackerState
from tracklab.pipeline import Pipeline
from tracklab.utils import monkeypatch_hydra, progress, wandb
from hydra.utils import instantiate
from omegaconf import OmegaConf
# Hydra reads HYDRA_FULL_ERROR; "1" requests complete stack traces instead of
# its condensed error message.
os.environ["HYDRA_FULL_ERROR"] = "1"
# Module-level logger shared by the functions in this file.
log = logging.getLogger(__name__)
# Install a filter that ignores every Python warning from this point on.
warnings.filterwarnings("ignore")
@hydra.main(version_base=None, config_path="pkg://tracklab.configs", config_name="config")
def main(cfg):
    """Build the pipeline from the Hydra config, then train, track and evaluate.

    Args:
        cfg: Configuration object supplied by the ``hydra.main`` decorator.

    Returns:
        Always ``0``.
    """
    device = init_environment(cfg)

    # The dataset is created first: the evaluator and every module receive it.
    tracking_dataset = instantiate(cfg.dataset)
    evaluator = instantiate(cfg.eval, tracking_dataset=tracking_dataset)
    log.info(f"{tracking_dataset}")

    # One instantiated module per name listed in cfg.pipeline (none if unset).
    pipeline_names = cfg.pipeline if cfg.pipeline is not None else []
    modules = [
        instantiate(cfg.modules[name], device=device, tracking_dataset=tracking_dataset)
        for name in pipeline_names
    ]
    pipeline = Pipeline(models=modules)

    # Training pass: only modules that report training_enabled are trained.
    for trainable in modules:
        if not trainable.training_enabled:
            continue
        trainable.train(
            tracking_dataset,
            pipeline,
            evaluator,
            OmegaConf.to_container(cfg.dataset, resolve=True),
        )

    # Tracking pass, followed by evaluation and the saved-state report.
    if cfg.test_tracking:
        eval_set = cfg.dataset.eval_set
        log.info(f"Starting tracking operation on {eval_set} set.")

        # The tracker state wraps the selected split; the engine drives the pipeline over it.
        tracker_state = TrackerState(
            tracking_dataset.sets[eval_set], pipeline=pipeline, **cfg.state
        )
        engine = instantiate(cfg.engine, modules=pipeline, tracker_state=tracker_state)

        engine.track_dataset()
        evaluate(cfg, evaluator, tracker_state)

        if tracker_state.save_file is not None:
            log.info(f"Saved state at : {tracker_state.save_file.resolve()}")

    close_environment()
    return 0
def set_sharing_strategy():
    """Switch ``torch.multiprocessing`` to the ``"file_system"`` sharing strategy."""
    strategy = "file_system"
    torch.multiprocessing.set_sharing_strategy(strategy)
def init_environment(cfg):
    """Configure progress display, torch sharing, logging and wandb, and pick a device.

    Args:
        cfg: Run configuration. ``use_rich`` and ``print_config`` are read here
            and the whole object is handed to ``wandb.init``.

    Returns:
        ``"mps"``, ``"cuda"`` or ``"cpu"``: the first backend, in that order,
        that torch reports as available.
    """
    use_rich = cfg.use_rich

    # For Hydra and Slurm compatibility
    progress.use_rich = use_rich
    set_sharing_strategy()  # Do not touch

    # Backend preference: MPS first, then CUDA, CPU as the fallback.
    if torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Only root handlers whose type is exactly logging.StreamHandler are
    # touched; subclasses are skipped. With rich enabled they are limited to
    # ERROR and a RichHandler is added for INFO output.
    # TODO : Fix for mmcv fix. This should be done in a nicer way
    # (the TODO concerns forcing the handlers to INFO when rich is disabled).
    stream_level = logging.ERROR if use_rich else logging.INFO
    for handler in log.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(stream_level)
    if use_rich:
        log.root.addHandler(rich.logging.RichHandler(level=logging.INFO))

    wandb.init(cfg)

    log.info(f"Run directory: {Path().absolute()}")
    log.info(f"Using device: '{device}'.")
    if cfg.print_config:
        log.info(OmegaConf.to_yaml(cfg))

    return device
def close_environment():
    """Finish the wandb session; the counterpart of ``init_environment``."""
    wandb.finish()
def evaluate(cfg, evaluator, tracker_state):
    """Run the evaluator on the tracker state unless evaluation is disabled.

    Args:
        cfg: Configuration object exposing ``get``. Its optional
            ``eval_tracking`` entry (default ``True``) toggles evaluation.
        evaluator: Object whose ``run(tracker_state)`` performs the evaluation.
        tracker_state: Tracking results passed unchanged to ``evaluator.run``.

    Returns:
        None.
    """
    # The flag is looked up once. The former third branch (a warning about a
    # partially tracked video) could never run: with the
    # `cfg.dataset.nframes == -1` check disabled, the flag alone decides.
    if not cfg.get("eval_tracking", True):
        log.warning("Skipping evaluation because 'eval_tracking' was set to False.")
        return

    log.info("Starting evaluation.")
    evaluator.run(tracker_state)
# Invoke the Hydra-decorated entry point only when this file is run as a script.
if __name__ == "__main__":
    main()