diff --git a/util/dist.py b/util/dist.py index 2b25a7a..5466b73 100644 --- a/util/dist.py +++ b/util/dist.py @@ -83,7 +83,7 @@ def all_gather(data): for size, tensor in zip(size_list, tensor_list): tensor = torch.split(tensor, [size, max_size - size], dim=0)[0] buffer = io.BytesIO(tensor.cpu().numpy()) - obj = torch.load(buffer) + obj = torch.load(buffer, map_location=device) data_list.append(obj) return data_list