-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodify_mnist_digits.py
More file actions
49 lines (41 loc) · 1.24 KB
/
modify_mnist_digits.py
File metadata and controls
49 lines (41 loc) · 1.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Program to relabel the MNIST label files for improved training: digit 9 keeps its label and every other digit is relabelled 0.
import struct
def convert(imgf, labelf, outf, n):
    """Write the first ``n`` samples of an IDX image/label pair as CSV rows.

    Each output line is ``label,pix0,pix1,...,pix783``: the label byte
    followed by the 28*28 pixel bytes of the matching image, all written as
    decimal integers.

    Args:
        imgf: path to the image file; a 16-byte header is skipped, then
            28*28 bytes are read per sample.
        labelf: path to the label file; an 8-byte header is skipped, then
            one byte is read per sample.
        outf: path of the CSV file to create (overwritten if it exists).
        n: number of samples to convert.

    Raises:
        ValueError: if either input file ends before ``n`` samples are read.
    """
    pixels_per_image = 28 * 28
    with open(imgf, "rb") as image_file, open(labelf, "rb") as label_file:
        # Skip the fixed-size headers: 16 bytes for images, 8 for labels.
        image_file.read(16)
        label_file.read(8)
        with open(outf, "w") as out_file:
            for index in range(n):
                label = label_file.read(1)
                pixels = image_file.read(pixels_per_image)
                if len(label) != 1 or len(pixels) != pixels_per_image:
                    raise ValueError(
                        "input ended after %d of %d samples" % (index, n))
                # bytearray yields ints when iterated, on Python 2 and 3 alike.
                values = bytearray(label + pixels)
                out_file.write(",".join(str(v) for v in values) + "\n")
def modify_labels(labelf, outf, keep=9, fill=0):
    """Copy an MNIST IDX label file, relabelling it for a one-vs-rest task.

    The 8-byte header (magic number and item count) is copied verbatim.
    Each label equal to ``keep`` is written unchanged; every other label is
    replaced by ``fill``. With the defaults, 9 stays 9 and all other digits
    become 0.

    Args:
        labelf: path of the source label file.
        outf: path of the relabelled file to create (overwritten if present).
        keep: label value to preserve (0-255).
        fill: value written in place of every other label (0-255).

    Raises:
        ValueError: if the header is truncated or does not carry the IDX1
            unsigned-byte magic number (2049), or if the file holds fewer
            labels than its header declares.
    """
    with open(labelf, "rb") as label_file:
        header = label_file.read(8)
        if len(header) != 8:
            raise ValueError("%s: truncated IDX header" % labelf)
        magic, num_items = struct.unpack(">II", header)
        # 0x00000801: unsigned-byte element type, one dimension.
        if magic != 2049:
            raise ValueError("%s: unexpected magic number %d" % (labelf, magic))
        # bytearray gives ints on both Python 2 and 3 and is writable to a
        # binary file; the original chr()-based write fails on Python 3.
        labels = bytearray(label_file.read(num_items))
    if len(labels) != num_items:
        raise ValueError("%s: header declares %d labels, found %d"
                         % (labelf, num_items, len(labels)))
    relabelled = bytearray(v if v == keep else fill for v in labels)
    with open(outf, "wb") as out_file:
        out_file.write(header)
        out_file.write(relabelled)
if __name__ == "__main__":
    # Relabel both the test (t10k) and training label sets. Guarded so that
    # importing this module (e.g. to reuse convert) does not rewrite files.
    modify_labels("./MNIST-data/t10k-labels-idx1-ubyte-orig",
                  "./MNIST-data/t10k-labels-idx1-ubyte-9to0")
    modify_labels("./MNIST-data/train-labels-idx1-ubyte-orig",
                  "./MNIST-data/train-labels-idx1-ubyte-9to0")