From 030e6294c136143fd4c23a102f318bc78c596406 Mon Sep 17 00:00:00 2001 From: zerlinwang Date: Tue, 29 Nov 2022 10:12:02 +0800 Subject: [PATCH] fix(wzl): fix shape error in vdn mix --- offpolicy/algorithms/vdn/algorithm/vdn_mixer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/offpolicy/algorithms/vdn/algorithm/vdn_mixer.py b/offpolicy/algorithms/vdn/algorithm/vdn_mixer.py index 91e6ee0..0f50117 100644 --- a/offpolicy/algorithms/vdn/algorithm/vdn_mixer.py +++ b/offpolicy/algorithms/vdn/algorithm/vdn_mixer.py @@ -36,5 +36,6 @@ def forward(self, agent_q_inps, states): if type(agent_q_inps) == np.ndarray: agent_q_inps = torch.FloatTensor(agent_q_inps) - - return agent_q_inps.sum(dim=-1).view(-1, 1, 1) + + batch_size = agent_q_inps.size(1) + return agent_q_inps.sum(dim=-1).view(-1, batch_size, 1, 1)