forked from hello-trouble/HardGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataloader.py
More file actions
26 lines (21 loc) · 919 Bytes
/
dataloader.py
File metadata and controls
26 lines (21 loc) · 919 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torchvision.transforms as tfs
import os
from PIL import Image
import numpy as np
from torch.utils import data
class Dataset(data.Dataset):
    """Paired cloudy / ground-truth image dataset read from disk.

    Expected directory layout::

        <path_root>/<mode>/cloud/<image files>
        <path_root>/<mode>/label/<image files>

    The i-th sample pairs the i-th file of ``cloud/`` with the i-th file of
    ``label/``, both taken in sorted filename order.
    """

    def __init__(self, path_root="./dataset/", mode="train"):
        """Index the cloud and label images for one split.

        Args:
            path_root: Dataset root directory. A trailing separator is optional.
            mode: Split sub-directory name (e.g. ``"train"``).

        Raises:
            ValueError: If ``cloud/`` and ``label/`` hold different numbers
                of entries, since index-based pairing would then be wrong.
        """
        super().__init__()
        # os.path.join instead of string concatenation so that a root given
        # without a trailing slash ("./dataset") still resolves correctly.
        self.path_root = os.path.join(path_root, mode)
        cloud_dir = os.path.join(self.path_root, "cloud")
        gt_dir = os.path.join(self.path_root, "label")
        # os.listdir order is arbitrary, so sort both listings; pairing is by
        # index. NOTE(review): assumes corresponding cloud/label files sort to
        # the same position (e.g. identical filenames) — confirm for the dataset.
        self.cloud_images_dir = sorted(os.listdir(cloud_dir))
        self.cloud_images = [os.path.join(cloud_dir, img) for img in self.cloud_images_dir]
        self.gt_images_dir = sorted(os.listdir(gt_dir))
        self.gt_images = [os.path.join(gt_dir, img) for img in self.gt_images_dir]
        if len(self.cloud_images) != len(self.gt_images):
            raise ValueError(
                f"cloud/label count mismatch under {self.path_root!r}: "
                f"{len(self.cloud_images)} cloud vs {len(self.gt_images)} label files"
            )

    def __getitem__(self, item):
        """Return the ``(cloud, gt)`` pair at index ``item``.

        Both images are loaded with PIL and converted with torchvision's
        ``ToTensor``.
        """
        to_tensor = tfs.ToTensor()
        # Context managers close the underlying file handles; without them,
        # PIL keeps the files open until garbage collection.
        with Image.open(self.cloud_images[item]) as cloud_img:
            cloud = to_tensor(cloud_img)
        with Image.open(self.gt_images[item]) as gt_img:
            gt = to_tensor(gt_img)
        return cloud, gt

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.cloud_images)