-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtoy_data.py
More file actions
93 lines (73 loc) · 3.95 KB
/
toy_data.py
File metadata and controls
93 lines (73 loc) · 3.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import copy
import os
import pickle

import numpy as np
import torch
from torch.utils.data import DataLoader
# Directory where generated dataset metadata (.pt files) is written.
os.environ["MODEL_DIR"] = "./saved_models"

# Toy datasets taken from the original Dirichlet Flow Matching for DNA paper.
# (Previously a bare string statement, which has no effect at runtime; a
# comment states the intent without the misleading no-op expression.)
class ToyDataset(torch.utils.data.IterableDataset):
    """Toy dataset from the original Dirichlet Flow Matching for DNA paper.

    Yields ``(sequence, class)`` pairs: a class index is drawn from
    ``class_probs``, then each of the ``seq_len`` tokens is sampled from that
    class's per-position categorical distribution over the alphabet.
    """

    def __init__(self, args, max_samples=512000):
        """
        Args:
            args: namespace with ``toy_num_cls``, ``toy_seq_len`` and
                ``toy_simplex_dim`` attributes.
            max_samples: number of samples ``__iter__`` yields before stopping
                (the upstream implementation looped forever).
        """
        super().__init__()
        self.num_cls = args.toy_num_cls
        self.seq_len = args.toy_seq_len
        self.alphabet_size = args.toy_simplex_dim
        self.max_samples = max_samples  # termination condition for __iter__
        # Per-(class, position) categorical distributions over the alphabet.
        self.probs = torch.softmax(torch.rand((self.num_cls, self.seq_len, self.alphabet_size)), dim=2)
        self.class_probs = torch.ones(self.num_cls)
        if self.num_cls > 1:
            # Class 0 gets probability 1/2; the other classes split the rest.
            self.class_probs = self.class_probs * 1 / 2 / (self.num_cls - 1)
            self.class_probs[0] = 1 / 2
        # Fix: exact float equality (`== 1`) fails whenever 1/(2*(num_cls-1))
        # is not exactly representable (e.g. num_cls == 4); compare with a
        # tolerance instead.
        assert torch.isclose(self.class_probs.sum(), torch.tensor(1.0)), \
            "class probabilities must sum to 1"
        distribution_dict = {'probs': self.probs, 'class_probs': self.class_probs}
        # Fix: create the output directory first — torch.save does not.
        os.makedirs(os.environ["MODEL_DIR"], exist_ok=True)
        torch.save(distribution_dict, os.path.join(os.environ["MODEL_DIR"], 'toy_distribution_dict.pt'))

    def __len__(self):
        # The upstream code reported an arbitrary huge length (100,000,000);
        # report the number of samples the iterator actually yields.
        return self.max_samples

    def __iter__(self):
        """Yield ``(seq, cls)`` pairs; ``seq`` is a LongTensor of shape (seq_len,)."""
        # The upstream implementation was an unbounded `while True` loop; this
        # version stops after max_samples so an epoch can terminate.
        count = 0
        while count < self.max_samples:
            cls = np.random.choice(a=self.num_cls, size=1, p=self.class_probs.numpy())
            seq = []
            for i in range(self.seq_len):
                seq.append(torch.multinomial(
                    self.probs[cls[0], i, :],
                    num_samples=1,
                    replacement=True
                ))
            yield torch.stack(seq).squeeze(-1), torch.tensor(cls)
            count += 1
class EnhancerDataset(torch.utils.data.Dataset):
    """Map-style dataset of enhancer sequences loaded from a pickled split.

    Sequences and class labels are stored one-hot in the pickle; both are
    collapsed to integer indices with ``argmax``.
    """

    def __init__(self, args, split='train'):
        """
        Args:
            args: namespace with a boolean ``mel_enhancer`` attribute selecting
                the DeepMEL2 vs. DeepFlyBrain data file.
            split: which split to load ('train', 'valid', ...) — keys into the
                pickled dict as ``f'{split}_data'`` / ``f'y_{split}'``.
        """
        path = f'data/the_code/General/data/Deep{"MEL2" if args.mel_enhancer else "FlyBrain"}_data.pkl'
        # Fix: the original leaked the file handle (open() with no close);
        # `pickle`/`copy` are provided by the top-of-file imports.
        # NOTE: pickle.load is only safe on trusted local data files.
        with open(path, 'rb') as f:
            all_data = pickle.load(f)
        # One-hot -> integer token indices / class indices.
        self.seqs = torch.argmax(torch.from_numpy(copy.deepcopy(all_data[f'{split}_data'])), dim=-1)
        self.clss = torch.argmax(torch.from_numpy(copy.deepcopy(all_data[f'y_{split}'])), dim=-1)
        self.num_cls = all_data[f'y_{split}'].shape[-1]
        self.alphabet_size = 4  # DNA: A, C, G, T

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        """Return the (sequence, class) pair at ``idx``."""
        return self.seqs[idx], self.clss[idx]
class TwoClassOverfitDataset(torch.utils.data.IterableDataset):
    """Infinite dataset over two fixed pools of random sequences (one per class).

    Each draw picks class 0 or 1 with probability 1/2, then yields a uniformly
    chosen sequence from that class's pool, making it trivial to overfit.
    """

    def __init__(self, args):
        """
        Args:
            args: namespace with ``toy_seq_len``, ``toy_simplex_dim``,
                ``toy_num_seq`` and ``cls_ckpt`` (path or None) attributes.
        """
        super().__init__()
        self.seq_len = args.toy_seq_len
        self.alphabet_size = args.toy_simplex_dim
        self.num_cls = 2
        if args.cls_ckpt is not None:
            # Reload the exact sequences saved next to the classifier checkpoint
            # so evaluation sees the same data the classifier was trained on.
            distribution_dict = torch.load(os.path.join(os.path.dirname(args.cls_ckpt), 'overfit_dataset.pt'))
            self.data_class1 = distribution_dict['data_class1']
            self.data_class2 = distribution_dict['data_class2']
        else:
            # toy_num_seq sequences per class, tokens drawn uniformly at random.
            self.data_class1 = torch.stack([torch.from_numpy(np.random.choice(np.arange(self.alphabet_size), size=args.toy_seq_len, replace=True)) for _ in range(args.toy_num_seq)])
            self.data_class2 = torch.stack([torch.from_numpy(np.random.choice(np.arange(self.alphabet_size), size=args.toy_seq_len, replace=True)) for _ in range(args.toy_num_seq)])
            distribution_dict = {'data_class1': self.data_class1, 'data_class2': self.data_class2}
            # Fix: create the output directory first — torch.save does not.
            os.makedirs(os.environ["MODEL_DIR"], exist_ok=True)
            torch.save(distribution_dict, os.path.join(os.environ["MODEL_DIR"], 'overfit_dataset.pt'))

    def __len__(self):
        # Effectively infinite: __iter__ never terminates on its own.
        return 10000000000

    def __iter__(self):
        """Yield ``(sequence, label)`` pairs forever; label is tensor([0]) or tensor([1])."""
        while True:
            if np.random.rand() < 0.5:
                yield self.data_class1[np.random.choice(np.arange(len(self.data_class1)))], torch.tensor([0])
            else:
                yield self.data_class2[np.random.choice(np.arange(len(self.data_class2)))], torch.tensor([1])