Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 14, 2024
1 parent 819c77f commit 8fe1b84
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
6 changes: 2 additions & 4 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,11 @@ def __init__(
self, type: InteractionType | str | None = InteractionType.DETERMINISTIC
) -> None:
super().__init__()
if not isinstance(type, InteractionType):
if not isinstance(type, InteractionType) and type is not None:
if isinstance(type, str):
type = InteractionType(type.lower())
else:
raise TypeError(
"Invalid type: only str or InteractionType are accepted."
)
raise ValueError(f"{type} is not a valid InteractionType")
self.type = type

def clone(self) -> set_interaction_type:
Expand Down
3 changes: 1 addition & 2 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,8 @@ def test_from_str_correct_conversion(self, str_and_expected_type):

@pytest.mark.parametrize("unsupported_type_str", ["foo"])
def test_from_str_correct_raise(self, unsupported_type_str):
with pytest.raises(ValueError) as err:
with pytest.raises(ValueError, match=" is not a valid InteractionType"):
InteractionType.from_str(unsupported_type_str)
assert unsupported_type_str in str(err) and "is unsupported" in str(err)


class TestTDModule:
Expand Down

0 comments on commit 8fe1b84

Please sign in to comment.