Handle option masking correctly when computing TD targets. #51

@Ueva

Description

```python
# Compute TD targets.
next_state_values = torch.zeros(self.batch_size, dtype=torch.float32)
with torch.no_grad():
    next_state_values[non_terminal_mask] = self.target(non_terminal_next_states).max(1).values
targets = reward_batch + self.gamma * next_state_values
```

Currently, the values of masked (unavailable) options are not excluded when computing TD targets for Macro-DQN updates, so the `max` in the target can select an option that cannot actually be initiated in the next state.

We need a function that produces an option mask for given states, which can be used here and elsewhere option masking is needed. Ideally, it should accept a batch of states and return a batch of option masks. Since it will potentially be called many times per time step, it should be performant.
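A minimal sketch of what this could look like, assuming each option exposes a vectorised `initiation(states)` predicate returning a per-state boolean tensor (a hypothetical interface; all names here are illustrative, not the repo's actual API):

```python
import torch

def compute_option_mask(options, states):
    """Batched option mask: entry (i, j) is True iff option j can be
    initiated in state i.

    options: sequence of objects with a vectorised ``initiation(states)``
             method returning a (batch,) bool tensor (assumed interface).
    states:  (batch, state_dim) float tensor.
    Returns: (batch, num_options) bool tensor.
    """
    return torch.stack([opt.initiation(states) for opt in options], dim=1)

def masked_td_targets(q_next, mask, rewards, gamma):
    """TD targets where unavailable options are excluded from the max.

    Masked entries are set to -inf so they can never be selected by
    ``max``; every state is assumed to have at least one valid option.
    """
    q_next = q_next.masked_fill(~mask, float("-inf"))
    return rewards + gamma * q_next.max(dim=1).values
```

Because the mask is built with a single stacked, vectorised call per option rather than a Python loop over states, it stays cheap even when invoked many times per time step.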

Labels: bug (Something isn't working)
