From 3d9e40891ffdd39d6a5bf56730d468ace142752f Mon Sep 17 00:00:00 2001 From: Shoufa Chen Date: Thu, 23 Sep 2021 23:05:21 +0800 Subject: [PATCH] fix: reduce gpu memory --- util/dist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/util/dist.py b/util/dist.py index 2b25a7a..5466b73 100644 --- a/util/dist.py +++ b/util/dist.py @@ -83,7 +83,7 @@ def all_gather(data): for size, tensor in zip(size_list, tensor_list): tensor = torch.split(tensor, [size, max_size - size], dim=0)[0] buffer = io.BytesIO(tensor.cpu().numpy()) - obj = torch.load(buffer) + obj = torch.load(buffer, map_location=device) data_list.append(obj) return data_list