-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathgenerate_mnist_data.py
More file actions
31 lines (24 loc) · 915 Bytes
/
generate_mnist_data.py
File metadata and controls
31 lines (24 loc) · 915 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import numpy as np
import torch
from scenarios.mnist_scenarios import MNISTScenarioZ, MNISTScenarioX, \
MNISTScenarioXZ
from scenarios.toy_scenarios import Standardizer
def create_dataset(scenario_class, dir, *, seed=527, num_train=20000,
                   num_dev=10000, num_test=10000, g_function="abs"):
    """Build one scenario's train/dev/test data and write it to ``dir``.

    The scenario is constructed as
    ``Standardizer(scenario_class(g_function=g_function))``. It is then set up
    with the requested split sizes, summarized with ``info()``, and saved with
    ``to_file(dir)``.

    Args:
        scenario_class: Scenario class to instantiate. It must accept a
            ``g_function`` keyword argument.
        dir: Destination handed unchanged to ``scenario.to_file``. The name
            shadows the ``dir`` builtin but is kept so callers that pass it
            by keyword keep working.
        seed: Value used to seed both NumPy's global RNG and PyTorch's global
            RNG before the scenario is built. Defaults to 527.
        num_train: Number of training examples passed to ``setup``.
        num_dev: Number of dev examples passed to ``setup``.
        num_test: Number of test examples passed to ``setup``.
        g_function: Forwarded as ``g_function`` to ``scenario_class``. Its
            meaning is defined by the scenario class. Defaults to "abs".
    """
    # Seed before constructing the scenario so that any randomness used in
    # construction or setup is reproducible (presumably ``setup`` samples
    # data; confirm against the scenario implementation).
    np.random.seed(seed)
    torch.manual_seed(seed)

    # NOTE(review): Standardizer presumably rescales the generated data;
    # confirm in scenarios.toy_scenarios.
    scenario = Standardizer(scenario_class(g_function=g_function))
    scenario.setup(num_train=num_train, num_dev=num_dev, num_test=num_test)
    scenario.info()
    scenario.to_file(dir)
if __name__ == "__main__":
    # Generate every MNIST scenario variant; each one is written to
    # data/<name>/main.
    targets = (
        (MNISTScenarioX, "mnist_x"),
        (MNISTScenarioZ, "mnist_z"),
        (MNISTScenarioXZ, "mnist_xz"),
    )
    for scenario_cls, name in targets:
        print(f"Creating {name} ...")
        create_dataset(scenario_cls, f"data/{name}/main")