-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdatasets.py
More file actions
26 lines (18 loc) · 852 Bytes
/
datasets.py
File metadata and controls
26 lines (18 loc) · 852 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from tensorflow.examples.tutorials.mnist import input_data
import model
class MNIST(model.TrainableData):
    """MNIST dataset wrapper built on TensorFlow's tutorial ``input_data`` helper.

    The underlying dataset object is loaded with ``one_hot=True``, so labels
    are one-hot encoded. Batch helpers default to the ``batch_size`` given at
    construction but accept an optional per-call override.
    """

    def __init__(self, batch_size, data_dir='/tmp/data/MNIST_data/'):
        """Load MNIST from ``data_dir``.

        Args:
            batch_size: Default number of examples per batch for the batch
                helpers and ``data_generator``.
            data_dir: Directory passed to ``input_data.read_data_sets``.
                NOTE(review): that helper presumably downloads the dataset
                into this directory when it is missing; confirm for the
                installed TF version.
        """
        self.batch_size = batch_size
        self.mnist = input_data.read_data_sets(data_dir, one_hot=True)

    def _resolve_batch_size(self, batch_size):
        """Return ``batch_size`` if given, otherwise the instance default."""
        return self.batch_size if batch_size is None else batch_size

    def data_generator(self, batch_size=None):
        """Yield training batches indefinitely.

        Each yielded item is ``([x_batch, y_batch], [y_batch, x_batch])``,
        where ``(x_batch, y_batch)`` is one training batch. The first list is
        the model inputs and the second the targets.
        NOTE(review): this pairing presumably serves a model that both
        classifies and reconstructs its input (Keras-style generator for
        ``fit_generator``); confirm against the model definition.

        Args:
            batch_size: Optional override of the instance's batch size.
        """
        # The generator never terminates; the consumer decides how many
        # batches to draw.
        while True:
            x_batch, y_batch = self.get_next_train_batch(batch_size)
            yield ([x_batch, y_batch], [y_batch, x_batch])

    def get_next_test_batch(self, batch_size=None):
        """Return the next batch from the test split.

        Args:
            batch_size: Optional override of the instance's batch size.

        Returns:
            Whatever ``self.mnist.test.next_batch`` returns (presumably an
            ``(images, labels)`` pair; confirm).
        """
        return self.mnist.test.next_batch(
            batch_size=self._resolve_batch_size(batch_size))

    def get_next_train_batch(self, batch_size=None):
        """Return the next batch from the training split.

        Args:
            batch_size: Optional override of the instance's batch size.

        Returns:
            Whatever ``self.mnist.train.next_batch`` returns; callers in this
            class unpack it as ``(x_batch, y_batch)``.
        """
        return self.mnist.train.next_batch(
            batch_size=self._resolve_batch_size(batch_size))

    def get_all_test_data(self):
        """Return the full test split as ``[images, labels]`` (labels one-hot)."""
        return [self.mnist.test.images, self.mnist.test.labels]

    def get_data_element(self):
        """Return the raw dataset object produced by ``read_data_sets``."""
        return self.mnist