```python
# Compute TD targets.
next_state_values = torch.zeros(self.batch_size, dtype=torch.float32)
with torch.no_grad():
    next_state_values[non_terminal_mask] = self.target(non_terminal_next_states).max(1).values
targets = reward_batch + self.gamma * next_state_values
```
Currently, the values of masked options are not ignored when computing TD targets for Macro-DQN updates.
We need a function that produces an option mask for given states, which can be used here and anywhere else option masking is needed. Ideally, it should accept batches of states and produce batches of option masks. This function will potentially be called many times per time step, so it should be performant.
BaRL-SimpleOptions/simpleoptions/function_approximation/agents/macro_dqn_options_agent.py
Lines 121 to 125 in f82e945
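One possible shape for this, sketched under assumptions: `batch_option_mask` and `initiation_fns` are hypothetical names (not part of the current codebase), and each option is assumed to expose a vectorised initiation check that maps a batch of states to a boolean tensor. Masked (unavailable) options are then excluded from the TD-target max by filling their Q-values with `-inf`:

```python
import torch


def batch_option_mask(states: torch.Tensor, initiation_fns) -> torch.Tensor:
    """Return a (batch, num_options) boolean mask; True where the option is available.

    states: (batch, state_dim) tensor.
    initiation_fns: one callable per option, each mapping a batch of
    states to a (batch,) boolean tensor (assumed interface).
    """
    masks = [fn(states) for fn in initiation_fns]
    return torch.stack(masks, dim=1)


# Applying the mask when computing TD targets (illustrative only):
# q_next = self.target(non_terminal_next_states)          # (batch, num_options)
# mask = batch_option_mask(non_terminal_next_states, initiation_fns)
# q_next = q_next.masked_fill(~mask, float("-inf"))       # masked options never win the max
# next_state_values[non_terminal_mask] = q_next.max(dim=1).values
```

Stacking per-option boolean tensors keeps the whole computation vectorised over the batch, so per-call overhead is one Python loop over options rather than over states. If even that is too slow, the per-option checks could be fused into a single tensor op, but that depends on how initiation sets are represented.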