diff --git a/pycardano/address.py b/pycardano/address.py index 3905b1c3..4d416f13 100644 --- a/pycardano/address.py +++ b/pycardano/address.py @@ -19,7 +19,7 @@ ) from pycardano.hash import VERIFICATION_KEY_HASH_SIZE, ScriptHash, VerificationKeyHash from pycardano.network import Network -from pycardano.serialization import CBORSerializable +from pycardano.serialization import CBORSerializable, limit_primitive_type __all__ = ["AddressType", "PointerAddress", "Address"] @@ -160,6 +160,7 @@ def to_primitive(self) -> bytes: return self.encode() @classmethod + @limit_primitive_type(bytes) def from_primitive(cls: Type[PointerAddress], value: bytes) -> PointerAddress: return cls.decode(value) @@ -339,6 +340,7 @@ def to_primitive(self) -> bytes: return bytes(self) @classmethod + @limit_primitive_type(bytes, str) def from_primitive(cls: Type[Address], value: Union[bytes, str]) -> Address: if isinstance(value, str): value = bytes(decode(value)) diff --git a/pycardano/hash.py b/pycardano/hash.py index 975da671..f8034fd5 100644 --- a/pycardano/hash.py +++ b/pycardano/hash.py @@ -2,7 +2,7 @@ from typing import Type, TypeVar, Union -from pycardano.serialization import CBORSerializable +from pycardano.serialization import CBORSerializable, limit_primitive_type __all__ = [ "VERIFICATION_KEY_HASH_SIZE", @@ -67,6 +67,7 @@ def to_primitive(self) -> bytes: return self.payload @classmethod + @limit_primitive_type(bytes, str) def from_primitive(cls: Type[T], value: Union[bytes, str]) -> T: if isinstance(value, str): value = bytes.fromhex(value) diff --git a/pycardano/key.py b/pycardano/key.py index 4b771423..78b9d8be 100644 --- a/pycardano/key.py +++ b/pycardano/key.py @@ -14,7 +14,7 @@ from pycardano.crypto.bip32 import BIP32ED25519PrivateKey, HDWallet from pycardano.exception import InvalidKeyTypeException from pycardano.hash import VERIFICATION_KEY_HASH_SIZE, VerificationKeyHash -from pycardano.serialization import CBORSerializable +from pycardano.serialization import CBORSerializable, limit_primitive_type __all__ = [ "Key", @@ -62,6 +62,7 @@ def to_primitive(self) -> bytes: return self.payload @classmethod + @limit_primitive_type(bytes) def from_primitive(cls: Type["Key"], value: bytes) -> Key: return cls(value) diff --git a/pycardano/metadata.py b/pycardano/metadata.py index d11ea642..1dc8c85a 100644 --- a/pycardano/metadata.py +++ b/pycardano/metadata.py @@ -16,6 +16,7 @@ DictCBORSerializable, MapCBORSerializable, Primitive, + limit_primitive_type, list_hook, ) @@ -91,6 +92,7 @@ def to_primitive(self) -> Primitive: return CBORTag(AlonzoMetadata.TAG, super(AlonzoMetadata, self).to_primitive()) @classmethod + @limit_primitive_type(CBORTag) def from_primitive(cls: Type[AlonzoMetadata], value: CBORTag) -> AlonzoMetadata: if not hasattr(value, "tag"): raise DeserializeException( diff --git a/pycardano/nativescript.py b/pycardano/nativescript.py index 4ea29ace..a05fc8eb 100644 --- a/pycardano/nativescript.py +++ b/pycardano/nativescript.py @@ -10,7 +10,12 @@ from pycardano.exception import DeserializeException from pycardano.hash import SCRIPT_HASH_SIZE, ScriptHash, VerificationKeyHash -from pycardano.serialization import ArrayCBORSerializable, Primitive, list_hook +from pycardano.serialization import ( + ArrayCBORSerializable, + Primitive, + limit_primitive_type, + list_hook, +) from pycardano.types import JsonDict __all__ = [ @@ -30,22 +35,12 @@ class NativeScript(ArrayCBORSerializable): json_field: ClassVar[str] @classmethod + @limit_primitive_type(list) def from_primitive( - cls: Type[NativeScript], value: Primitive + cls: Type[NativeScript], value: list ) -> Union[ ScriptPubkey, ScriptAll, ScriptAny, ScriptNofK, InvalidBefore, InvalidHereAfter ]: - if not isinstance( - value, - ( - list, - tuple, - ), - ): - raise DeserializeException( - f"A list or a tuple is required for deserialization: {str(value)}" - ) - script_type: int = value[0] if script_type == ScriptPubkey._TYPE: return super(NativeScript, ScriptPubkey).from_primitive(value[1:]) diff --git a/pycardano/network.py b/pycardano/network.py index 55f1e265..38489cd0 100644 --- a/pycardano/network.py +++ b/pycardano/network.py @@ -5,8 +5,7 @@ from enum import Enum from typing import Type -from pycardano.exception import DeserializeException -from pycardano.serialization import CBORSerializable, Primitive +from pycardano.serialization import CBORSerializable, limit_primitive_type __all__ = ["Network"] @@ -23,9 +22,6 @@ def to_primitive(self) -> int: return self.value @classmethod - def from_primitive(cls: Type[Network], value: Primitive) -> Network: - if not isinstance(value, int): - raise DeserializeException( - f"An integer value is required for deserialization: {str(value)}" - ) + @limit_primitive_type(int) + def from_primitive(cls: Type[Network], value: int) -> Network: return cls(value) diff --git a/pycardano/plutus.py b/pycardano/plutus.py index 179f0936..a2e058fe 100644 --- a/pycardano/plutus.py +++ b/pycardano/plutus.py @@ -6,7 +6,7 @@ import json from dataclasses import dataclass, field, fields from enum import Enum -from typing import Any, ClassVar, List, Optional, Type, Union +from typing import Any, ClassVar, Optional, Type, Union import cbor2 from cbor2 import CBORTag @@ -21,9 +21,9 @@ CBORSerializable, DictCBORSerializable, IndefiniteList, - Primitive, RawCBOR, default_encoder, + limit_primitive_type, ) __all__ = [ @@ -66,6 +66,7 @@ def to_shallow_primitive(self) -> dict: return result @classmethod + @limit_primitive_type(dict) def from_primitive(cls: Type[CostModels], value: dict) -> CostModels: raise DeserializeException( "Deserialization of cost model is impossible, because some information is lost " @@ -480,11 +481,8 @@ def to_shallow_primitive(self) -> CBORTag: return CBORTag(102, [self.CONSTR_ID, primitives]) @classmethod + @limit_primitive_type(CBORTag) def from_primitive(cls: Type[PlutusData], value: CBORTag) -> PlutusData: - if not isinstance(value, CBORTag): - raise DeserializeException( - f"Unexpected type: {CBORTag}. Got {type(value)} instead." - ) if value.tag == 102: tag = value.value[0] if tag != cls.CONSTR_ID: @@ -643,6 +641,7 @@ def _dfs(obj): return _dfs(self.data) @classmethod + @limit_primitive_type(CBORTag) def from_primitive(cls: Type[RawPlutusData], value: CBORTag) -> RawPlutusData: return cls(value) @@ -675,6 +674,7 @@ def to_primitive(self) -> int: return self.value @classmethod + @limit_primitive_type(int) def from_primitive(cls: Type[RedeemerTag], value: int) -> RedeemerTag: return cls(value) @@ -704,7 +704,8 @@ class Redeemer(ArrayCBORSerializable): ex_units: ExecutionUnits = None @classmethod - def from_primitive(cls: Type[Redeemer], values: List[Primitive]) -> Redeemer: + @limit_primitive_type(list) + def from_primitive(cls: Type[Redeemer], values: list) -> Redeemer: if isinstance(values[2], CBORTag) and cls is Redeemer: values[2] = RawPlutusData.from_primitive(values[2]) redeemer = super(Redeemer, cls).from_primitive( diff --git a/pycardano/serialization.py b/pycardano/serialization.py index f0a4bf0b..c3743c95 100644 --- a/pycardano/serialization.py +++ b/pycardano/serialization.py @@ -8,6 +8,7 @@ from dataclasses import Field, dataclass, fields from datetime import datetime from decimal import Decimal +from functools import wraps from inspect import isclass from typing import Any, Callable, List, Type, TypeVar, Union, get_type_hints @@ -105,6 +106,31 @@ class RawCBOR: `Cbor2 encoder `_ directly. """ + +def limit_primitive_type(*allowed_types): + """ + A helper function to validate primitive type given to from_primitive class methods + + Not exposed to public by intention. + """ + + def decorator(func): + @wraps(func) + def wrapper(cls, value: Primitive): + if not isinstance(value, allowed_types): + allowed_types_str = [ + allowed_type.__name__ for allowed_type in allowed_types + ] + raise DeserializeException( + f"{allowed_types_str} typed value is required for deserialization. Got {type(value)}: {value}" + ) + return func(cls, value) + + return wrapper + + return decorator + + CBORBase = TypeVar("CBORBase", bound="CBORSerializable") @@ -245,7 +271,7 @@ def to_validated_primitive(self) -> Primitive: return self.to_primitive() @classmethod - def from_primitive(cls: Type[CBORBase], value: Primitive) -> CBORBase: + def from_primitive(cls: Type[CBORBase], value: Any) -> CBORBase: """Turn a CBOR primitive to its original class type. Args: @@ -407,7 +433,7 @@ def _restore_dataclass_field( elif t in PRIMITIVE_TYPES and isinstance(v, t): return v raise DeserializeException( - f"Cannot deserialize object: \n{str(v)}\n in any valid type from {t_args}." + f"Cannot deserialize object: \n{v}\n in any valid type from {t_args}." ) return v @@ -494,7 +520,8 @@ def to_shallow_primitive(self) -> List[Primitive]: return primitives @classmethod - def from_primitive(cls: Type[ArrayBase], values: Primitive) -> ArrayBase: + @limit_primitive_type(list) + def from_primitive(cls: Type[ArrayBase], values: list) -> ArrayBase: """Restore a primitive value to its original class type. Args: @@ -508,10 +535,6 @@ def from_primitive(cls: Type[ArrayBase], values: Primitive) -> ArrayBase: DeserializeException: When the object could not be restored from primitives. """ all_fields = [f for f in fields(cls) if f.init] - if type(values) != list: - raise DeserializeException( - f"Expect input value to be a list, got a {type(values)} instead." - ) restored_vals = [] type_hints = get_type_hints(cls) @@ -606,7 +629,8 @@ def to_shallow_primitive(self) -> Primitive: return primitives @classmethod - def from_primitive(cls: Type[MapBase], values: Primitive) -> MapBase: + @limit_primitive_type(dict) + def from_primitive(cls: Type[MapBase], values: dict) -> MapBase: """Restore a primitive value to its original class type. Args: @@ -620,10 +644,6 @@ def from_primitive(cls: Type[MapBase], values: Primitive) -> MapBase: :class:`pycardano.exception.DeserializeException`: When the object could not be restored from primitives. """ all_fields = {f.metadata.get("key", f.name): f for f in fields(cls) if f.init} - if type(values) != dict: - raise DeserializeException( - f"Expect input value to be a dict, got a {type(values)} instead." - ) kwargs = {} type_hints = get_type_hints(cls) @@ -725,7 +745,8 @@ def _get_sortable_val(key): return dict(sorted(self.data.items(), key=lambda x: _get_sortable_val(x[0]))) @classmethod - def from_primitive(cls: Type[DictBase], value: Primitive) -> DictBase: + @limit_primitive_type(dict) + def from_primitive(cls: Type[DictBase], value: dict) -> DictBase: """Restore a primitive value to its original class type. Args: @@ -739,11 +760,7 @@ def from_primitive(cls: Type[DictBase], value: Primitive) -> DictBase: DeserializeException: When the object could not be restored from primitives. """ if not value: - raise DeserializeException(f"Cannot accept empty value {str(value)}.") - if not isinstance(value, dict): - raise DeserializeException( - f"A dictionary value is required for deserialization: {str(value)}" - ) + raise DeserializeException(f"Cannot accept empty value {value}.") restored = cls() for k, v in value.items(): diff --git a/pyproject.toml b/pyproject.toml index 06a17a79..8d37a2e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,11 +63,11 @@ profile = "black" [tool.mypy] ignore_missing_imports = true +disable_error_code = ["str-bytes-safe"] python_version = 3.7 exclude = [ '^pycardano/cip/cip8.py$', '^pycardano/crypto/bech32.py$', - '^pycardano/address.py$', '^pycardano/certificate.py$', '^pycardano/coinselection.py$', '^pycardano/exception.py$', diff --git a/test/pycardano/test_address.py b/test/pycardano/test_address.py index ea1c2b54..d239b1cf 100644 --- a/test/pycardano/test_address.py +++ b/test/pycardano/test_address.py @@ -1,4 +1,7 @@ -from pycardano.address import Address +from unittest import TestCase + +from pycardano.address import Address, PointerAddress +from pycardano.exception import DeserializeException from pycardano.key import PaymentVerificationKey from pycardano.network import Network @@ -15,3 +18,27 @@ def test_payment_addr(): Address(vk.hash(), network=Network.TESTNET).encode() == "addr_test1vr2p8st5t5cxqglyjky7vk98k7jtfhdpvhl4e97cezuhn0cqcexl7" ) + + +class PointerAddressTest(TestCase): + def test_from_primitive_invalid_value(self): + with self.assertRaises(DeserializeException): + PointerAddress.from_primitive(1) + + with self.assertRaises(DeserializeException): + PointerAddress.from_primitive([]) + + with self.assertRaises(DeserializeException): + PointerAddress.from_primitive({}) + + +class AddressTest(TestCase): + def test_from_primitive_invalid_value(self): + with self.assertRaises(DeserializeException): + Address.from_primitive(1) + + with self.assertRaises(DeserializeException): + Address.from_primitive([]) + + with self.assertRaises(DeserializeException): + Address.from_primitive({}) diff --git a/test/pycardano/test_serialization.py b/test/pycardano/test_serialization.py index b496c99c..d406de75 100644 --- a/test/pycardano/test_serialization.py +++ b/test/pycardano/test_serialization.py @@ -1,13 +1,40 @@ from dataclasses import dataclass, field from test.pycardano.util import check_two_way_cbor +import pytest + +from pycardano.exception import DeserializeException from pycardano.serialization import ( ArrayCBORSerializable, + CBORSerializable, DictCBORSerializable, MapCBORSerializable, + limit_primitive_type, ) +@pytest.mark.single +def test_limit_primitive_type(): + class MockClass(CBORSerializable): + @classmethod + def from_primitive(*args): + return + + wrapped = limit_primitive_type(int, str, bytes, list, dict, tuple, dict)( + MockClass.from_primitive + ) + wrapped(MockClass, 1) + wrapped(MockClass, "") + wrapped(MockClass, b"") + wrapped(MockClass, []) + wrapped(MockClass, tuple()) + wrapped(MockClass, {}) + + wrapped = limit_primitive_type(int)(MockClass.from_primitive) + with pytest.raises(DeserializeException): + wrapped(MockClass, "") + + def test_array_cbor_serializable(): @dataclass class Test1(ArrayCBORSerializable):