-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathextract_model_meta.py
More file actions
60 lines (51 loc) · 2.06 KB
/
extract_model_meta.py
File metadata and controls
60 lines (51 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import json
import pandas as pd
from gluonts.model.deepar import DeepAREstimator
from gluonts.model.deepstate import DeepStateEstimator
from gluonts.model.deep_factor import DeepFactorEstimator
from gluonts.model.renewal import DeepRenewalProcessEstimator
from gluonts.model.gp_forecaster import GaussianProcessEstimator
from gluonts.model.seq2seq import MQCNNEstimator
from gluonts.model.seq2seq import MQRNNEstimator
from gluonts.model.n_beats import NBEATSEstimator
from gluonts.model.simple_feedforward import SimpleFeedForwardEstimator
from gluonts.model.transformer import TransformerEstimator
from gluonts.model.wavenet import WaveNetEstimator
from gluonts.model.rotbaum import TreeEstimator
from gluonts.model.tft import TemporalFusionTransformerEstimator
from gluonts.model.seasonal_naive import SeasonalNaiveEstimator
from methods import ARIMAWrapper
# Registry of supported forecasting models.
# Each key is a normalised model name: lower-case, with spaces and hyphens
# stripped. The metadata loop below derives the same form from the
# "model + paper" CSV column to look entries up here. Each value is the
# estimator class whose ``__module__`` and ``__name__`` get recorded.
# Entry order matches the original literal.
MODELS = dict(
    deepar=DeepAREstimator,
    deepstate=DeepStateEstimator,
    deepfactor=DeepFactorEstimator,
    deeprenewalprocesses=DeepRenewalProcessEstimator,
    gpforecaster=GaussianProcessEstimator,
    mqcnn=MQCNNEstimator,
    mqrnn=MQRNNEstimator,
    nbeats=NBEATSEstimator,
    rotbaum=TreeEstimator,
    temporalfusiontransformer=TemporalFusionTransformerEstimator,
    transformer=TransformerEstimator,
    wavenet=WaveNetEstimator,
    simplefeedforward=SimpleFeedForwardEstimator,
    naiveseasonal=SeasonalNaiveEstimator,
    # ARIMAWrapper comes from the project-local ``methods`` module.
    arima=ARIMAWrapper,
)
# Read the semicolon-separated model overview table and build one metadata
# record per row. Each record is keyed by the normalised model name and
# written to meta_model2.json.
csv = pd.read_csv('gluonts_models.csv', sep=';')
meta = {}
for _, row in csv.iterrows():
    # Normalise column names to lower-case by building a fresh dict.
    # The previous in-place "insert lower-cased key, then del original key"
    # approach deleted any column whose name was already lower-case, because
    # the inserted key and the deleted key were the same.
    # pandas reads empty cells as NaN, and json.dump would emit NaN as the
    # bare token ``NaN``, which is not valid JSON. Map such cells to None,
    # which is written as ``null``.
    info = {
        column.lower(): (None if pd.isna(value) else value)
        for column, value in row.to_dict().items()
    }

    # Same normalisation as the MODELS keys: lower-case, no hyphens or spaces.
    key = info['model + paper'].lower().replace('-', '').replace(' ', '')

    # Rows with no matching entry in MODELS are still recorded, with
    # module/class set to None.
    estimator = MODELS.get(key)
    info['module'] = estimator.__module__ if estimator is not None else None
    info['class'] = estimator.__name__ if estimator is not None else None

    # NOTE(review): a later row with the same normalised name overwrites an
    # earlier one; confirm the CSV has no duplicates.
    meta[key] = info

with open('meta_model2.json', 'w') as mf:
    json.dump(meta, mf, indent=4)