diff --git a/offpolicy/algorithms/vdn/algorithm/vdn_mixer.py b/offpolicy/algorithms/vdn/algorithm/vdn_mixer.py index 91e6ee0..0f50117 100644 --- a/offpolicy/algorithms/vdn/algorithm/vdn_mixer.py +++ b/offpolicy/algorithms/vdn/algorithm/vdn_mixer.py @@ -36,5 +36,6 @@ def forward(self, agent_q_inps, states): if type(agent_q_inps) == np.ndarray: agent_q_inps = torch.FloatTensor(agent_q_inps) - - return agent_q_inps.sum(dim=-1).view(-1, 1, 1) + + batch_size = agent_q_inps.size(1) + return agent_q_inps.sum(dim=-1).view(-1, batch_size, 1, 1)