diff --git a/src/gluonts/core/serde/_base.py b/src/gluonts/core/serde/_base.py index 9efee3aab8..97de04d65d 100644 --- a/src/gluonts/core/serde/_base.py +++ b/src/gluonts/core/serde/_base.py @@ -287,6 +287,15 @@ def encode_partial(v: partial) -> Any: } +decode_disallow = [ + eval, + exec, + compile, + open, + input, +] + + def decode(r: Any) -> Any: """ Decodes a value from an intermediate representation `r`. @@ -312,7 +321,10 @@ def decode(r: Any) -> Any: kind = r["__kind__"] cls = cast(Any, locate(r["class"])) - assert cls is not None, f"Can not locate {r['class']}." + if cls is None: + raise ValueError(f"Cannot locate {r['class']}.") + if cls in decode_disallow: + raise ValueError(f"{r['class']} cannot be run.") if kind == Kind.Type: return cls diff --git a/test/core/test_serde.py b/test/core/test_serde.py index 34ecdcb9ff..b651874d1c 100644 --- a/test/core/test_serde.py +++ b/test/core/test_serde.py @@ -142,3 +142,26 @@ def test_serde_method(): def test_np_str_dtype(): a = np.array(["foo"]) serde.decode(serde.encode(a.dtype)) == a.dtype + + +@pytest.mark.parametrize( + "obj", + [ + {"__kind__": 42, "class": cls_str} + for cls_str in [ + "builtins.eval", + "builtins.exec", + "builtins.compile", + "builtins.open", + "builtins.input", + "eval", + "exec", + "compile", + "open", + "input", + ] + ], +) +def test_decode_disallow(obj): + with pytest.raises(ValueError): + serde.decode(obj)