-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
97 lines (78 loc) · 3.19 KB
/
dataset.py
File metadata and controls
97 lines (78 loc) · 3.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
import os
import pandas as pd
from torchvision.datasets.folder import default_loader
class CUBDataSet(torch.utils.data.Dataset):
    """CUB-200 style image classification dataset.

    Reads three space-separated, header-less metadata files from ``root``:
    ``images.txt`` (id, relative path), ``image_class_labels.txt``
    (id, label) and ``train_test_split.txt`` (id, train flag). Images are
    loaded lazily from ``root/images``.

    Args:
        root: Dataset root directory holding the metadata files and the
            ``images`` folder.
        train: If true, keep rows whose train flag equals 1; otherwise keep
            rows whose train flag equals 0.
        transform: Optional callable applied to every loaded image.

    Raises:
        RuntimeError: If the selected split contains no images.
    """

    def __init__(self, root, train=True, transform=None):
        img_folder = os.path.join(root, "images")

        def _read(file_name, columns):
            # Metadata files are space-separated and have no header row.
            return pd.read_csv(os.path.join(root, file_name), sep=" ", header=None, names=columns)

        img_paths = _read("images.txt", ["idx", "path"])
        img_labels = _read("image_class_labels.txt", ["idx", "label"])
        train_test_split = _read("train_test_split.txt", ["idx", "train_flag"])

        # Join on the image id instead of by row position, so the three files
        # need not list images in the same order. Row order follows images.txt.
        data = img_paths.merge(img_labels, on="idx").merge(train_test_split, on="idx")
        # .copy() makes the filtered frame independent, so the label
        # assignment below does not write into a slice of `data`.
        data = data[data["train_flag"] == int(train)].copy()
        # Shift labels down by one (metadata ids are presumably 1-based) so
        # targets start at 0.
        data["label"] = data["label"] - 1
        imgs = data.reset_index(drop=True)
        if len(imgs) == 0:
            split = "train" if train else "test"
            raise RuntimeError(f"no {split} images found in {root}")
        self.transform = transform
        self.root = img_folder
        self.imgs = imgs

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at position ``index``.

        ``target`` is the zero-shifted class label as a Python ``int``.
        """
        item = self.imgs.iloc[index]
        img = default_loader(os.path.join(self.root, item["path"]))
        if self.transform is not None:
            img = self.transform(img)
        return img, int(item["label"])

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.imgs)
class CUB200Pair(CUBDataSet):
    """CUB200 dataset that yields two views of each image instead of a label.

    Each item is ``(view_1, view_2)``: ``transform`` is called twice on the
    same loaded image, so a randomized transform produces two different
    augmentations. If no transform is set, the untransformed image is
    returned twice.
    """

    def __getitem__(self, index):
        """Return two views of the image at position ``index``."""
        item = self.imgs.iloc[index]
        img = default_loader(os.path.join(self.root, item["path"]))
        if self.transform is None:
            # Without a transform there is nothing to augment; return the
            # image twice rather than referencing undefined views.
            return img, img
        return self.transform(img), self.transform(img)
class MyDataSet(torch.utils.data.Dataset):
    """Image classification dataset built from a folder-per-class layout.

    Every directory found under ``img_root``, at any depth, is treated as a
    class named after the directory. The regular files directly inside it
    become that class's samples. Files sitting directly in ``img_root`` are
    ignored. Directories that share a name share a class id.

    Directory and file listings are sorted, so sample order and class ids are
    deterministic regardless of filesystem listing order.

    Args:
        img_root: Root directory to scan.
        transform: Optional callable applied to every loaded image.

    Attributes:
        imgs: List of sample file paths.
        labels: List of integer class ids, parallel to ``imgs``.
        class_to_idx: Mapping from directory name to class id.
    """

    def __init__(self, img_root, transform=None):
        file_paths = []
        labels_list = []
        labels = {}
        for root, dirs, _files in os.walk(img_root):
            # Sorting in place also fixes the order in which os.walk descends.
            dirs.sort()
            for dir_name in dirs:
                subdir_path = os.path.join(root, dir_name)
                # The first time a directory name is seen, it gets the next free id.
                label = labels.setdefault(dir_name, len(labels))
                for file_name in sorted(os.listdir(subdir_path)):
                    file_path = os.path.join(subdir_path, file_name)
                    # Nested directories are visited by os.walk on their own;
                    # do not record them as image files here.
                    if os.path.isfile(file_path):
                        file_paths.append(file_path)
                        labels_list.append(label)
        self.imgs = file_paths
        self.labels = labels_list
        self.class_to_idx = labels
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at position ``index``."""
        img = default_loader(self.imgs[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, self.labels[index]

    def __len__(self):
        """Return the number of collected image files."""
        return len(self.imgs)
class MyPair(MyDataSet):
    """Folder-per-class dataset that yields two views of each image instead of a label.

    Each item is ``(view_1, view_2)``: ``transform`` is called twice on the
    same loaded image, so a randomized transform produces two different
    augmentations. If no transform is set, the untransformed image is
    returned twice.
    """

    def __getitem__(self, index):
        """Return two views of the image at position ``index``."""
        img = default_loader(self.imgs[index])
        if self.transform is None:
            # Without a transform there is nothing to augment; return the
            # image twice rather than referencing undefined views.
            return img, img
        return self.transform(img), self.transform(img)