diff --git a/src/gfn/env.py b/src/gfn/env.py index 9b045ca3..510d3820 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -219,6 +219,10 @@ def _step( not_done_actions = actions[~new_sink_states_idx] new_not_done_states_tensor = self.step(not_done_states, not_done_actions) + if not isinstance(new_not_done_states_tensor, torch.Tensor): + raise Exception( + "User implemented env.step function *must* return a torch.Tensor!" + ) new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor