diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 497394a99..72deae8cc 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -79,8 +79,11 @@ def __init__( self, type: InteractionType | str | None = InteractionType.DETERMINISTIC ) -> None: super().__init__() - if isinstance(type, str): - type = InteractionType(type.lower()) + if not isinstance(type, InteractionType) and type is not None: + if isinstance(type, str): + type = InteractionType(type.lower()) + else: + raise ValueError(f"{type} is not a valid InteractionType") self.type = type def clone(self) -> set_interaction_type: diff --git a/test/test_nn.py b/test/test_nn.py index e27cb9b9d..274d08b29 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -88,9 +88,8 @@ def test_from_str_correct_conversion(self, str_and_expected_type): @pytest.mark.parametrize("unsupported_type_str", ["foo"]) def test_from_str_correct_raise(self, unsupported_type_str): - with pytest.raises(ValueError) as err: + with pytest.raises(ValueError, match=" is not a valid InteractionType"): InteractionType.from_str(unsupported_type_str) - assert unsupported_type_str in str(err) and "is unsupported" in str(err) class TestTDModule: