-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataloader.py
More file actions
24 lines (20 loc) · 903 Bytes
/
dataloader.py
File metadata and controls
24 lines (20 loc) · 903 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from torch.utils.data import DataLoader
from dataset.jsondataset import JsonDataset
from .dist import get_rank, get_world_size
def create_dataloader(config):
    """Build the training dataset and its DataLoader from ``config``.

    Reads ``config.train_file``, ``config.distributed``, ``config.batch_size``
    and ``config.workers``; the whole ``config`` object is also forwarded to
    ``JsonDataset``.

    When ``config.distributed`` is truthy, a shuffling ``DistributedSampler``
    (sized by this process's world size and rank) feeds the loader and the
    loader itself does not shuffle. Otherwise no sampler is used and the
    loader shuffles on its own. In both cases the loader pins memory and
    drops the final incomplete batch.

    Args:
        config: Configuration object exposing the attributes listed above.

    Returns:
        tuple: ``(dataset_train, loader_train)``.
    """
    dataset_train = JsonDataset(config.train_file, config=config)

    rank = get_rank()
    world_size = get_world_size()
    # Unconditional: every process prints this line, not only rank 0.
    print(f"distributed debug message, RANK: {rank} and WORLD_SIZE: {world_size}")

    if config.distributed:
        # NOTE(review): DistributedSampler reshuffles between epochs only if the
        # training loop calls ``loader_train.sampler.set_epoch(epoch)`` — confirm
        # the caller does this.
        sampler = torch.utils.data.DistributedSampler(
            dataset_train,
            shuffle=True,
            num_replicas=world_size,
            rank=rank,
        )
    else:
        sampler = None

    loader_train = DataLoader(
        dataset_train,
        batch_size=config.batch_size,
        # A sampler and shuffle=True are mutually exclusive in DataLoader, so
        # shuffling is delegated to the sampler whenever one exists.
        shuffle=sampler is None,
        num_workers=config.workers,
        pin_memory=True,
        drop_last=True,
        sampler=sampler,
    )

    # Report the dataset size from a single process only.
    if rank == 0:
        print("Total training images: ", len(dataset_train))

    return dataset_train, loader_train