Skip to content

Improve address type hint #130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pycardano/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
)
from pycardano.hash import VERIFICATION_KEY_HASH_SIZE, ScriptHash, VerificationKeyHash
from pycardano.network import Network
from pycardano.serialization import CBORSerializable
from pycardano.serialization import CBORSerializable, limit_primitive_type

__all__ = ["AddressType", "PointerAddress", "Address"]

Expand Down Expand Up @@ -160,6 +160,7 @@ def to_primitive(self) -> bytes:
return self.encode()

@classmethod
@limit_primitive_type(bytes)
def from_primitive(cls: Type[PointerAddress], value: bytes) -> PointerAddress:
return cls.decode(value)

Expand Down Expand Up @@ -339,6 +340,7 @@ def to_primitive(self) -> bytes:
return bytes(self)

@classmethod
@limit_primitive_type(bytes, str)
def from_primitive(cls: Type[Address], value: Union[bytes, str]) -> Address:
if isinstance(value, str):
value = bytes(decode(value))
Expand Down
3 changes: 2 additions & 1 deletion pycardano/hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Type, TypeVar, Union

from pycardano.serialization import CBORSerializable
from pycardano.serialization import CBORSerializable, limit_primitive_type

__all__ = [
"VERIFICATION_KEY_HASH_SIZE",
Expand Down Expand Up @@ -67,6 +67,7 @@ def to_primitive(self) -> bytes:
return self.payload

@classmethod
@limit_primitive_type(bytes, str)
def from_primitive(cls: Type[T], value: Union[bytes, str]) -> T:
if isinstance(value, str):
value = bytes.fromhex(value)
Expand Down
3 changes: 2 additions & 1 deletion pycardano/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pycardano.crypto.bip32 import BIP32ED25519PrivateKey, HDWallet
from pycardano.exception import InvalidKeyTypeException
from pycardano.hash import VERIFICATION_KEY_HASH_SIZE, VerificationKeyHash
from pycardano.serialization import CBORSerializable
from pycardano.serialization import CBORSerializable, limit_primitive_type

__all__ = [
"Key",
Expand Down Expand Up @@ -62,6 +62,7 @@ def to_primitive(self) -> bytes:
return self.payload

@classmethod
@limit_primitive_type(bytes)
def from_primitive(cls: Type["Key"], value: bytes) -> Key:
return cls(value)

Expand Down
2 changes: 2 additions & 0 deletions pycardano/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
DictCBORSerializable,
MapCBORSerializable,
Primitive,
limit_primitive_type,
list_hook,
)

Expand Down Expand Up @@ -91,6 +92,7 @@ def to_primitive(self) -> Primitive:
return CBORTag(AlonzoMetadata.TAG, super(AlonzoMetadata, self).to_primitive())

@classmethod
@limit_primitive_type(CBORTag)
def from_primitive(cls: Type[AlonzoMetadata], value: CBORTag) -> AlonzoMetadata:
if not hasattr(value, "tag"):
raise DeserializeException(
Expand Down
21 changes: 8 additions & 13 deletions pycardano/nativescript.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@

from pycardano.exception import DeserializeException
from pycardano.hash import SCRIPT_HASH_SIZE, ScriptHash, VerificationKeyHash
from pycardano.serialization import ArrayCBORSerializable, Primitive, list_hook
from pycardano.serialization import (
ArrayCBORSerializable,
Primitive,
limit_primitive_type,
list_hook,
)
from pycardano.types import JsonDict

__all__ = [
Expand All @@ -30,22 +35,12 @@ class NativeScript(ArrayCBORSerializable):
json_field: ClassVar[str]

@classmethod
@limit_primitive_type(list)
def from_primitive(
cls: Type[NativeScript], value: Primitive
cls: Type[NativeScript], value: list
) -> Union[
ScriptPubkey, ScriptAll, ScriptAny, ScriptNofK, InvalidBefore, InvalidHereAfter
]:
if not isinstance(
value,
(
list,
tuple,
),
):
raise DeserializeException(
f"A list or a tuple is required for deserialization: {str(value)}"
)

script_type: int = value[0]
if script_type == ScriptPubkey._TYPE:
return super(NativeScript, ScriptPubkey).from_primitive(value[1:])
Expand Down
10 changes: 3 additions & 7 deletions pycardano/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from enum import Enum
from typing import Type

from pycardano.exception import DeserializeException
from pycardano.serialization import CBORSerializable, Primitive
from pycardano.serialization import CBORSerializable, limit_primitive_type

__all__ = ["Network"]

Expand All @@ -23,9 +22,6 @@ def to_primitive(self) -> int:
return self.value

@classmethod
def from_primitive(cls: Type[Network], value: Primitive) -> Network:
if not isinstance(value, int):
raise DeserializeException(
f"An integer value is required for deserialization: {str(value)}"
)
@limit_primitive_type(int)
def from_primitive(cls: Type[Network], value: int) -> Network:
return cls(value)
15 changes: 8 additions & 7 deletions pycardano/plutus.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json
from dataclasses import dataclass, field, fields
from enum import Enum
from typing import Any, ClassVar, List, Optional, Type, Union
from typing import Any, ClassVar, Optional, Type, Union

import cbor2
from cbor2 import CBORTag
Expand All @@ -21,9 +21,9 @@
CBORSerializable,
DictCBORSerializable,
IndefiniteList,
Primitive,
RawCBOR,
default_encoder,
limit_primitive_type,
)

__all__ = [
Expand Down Expand Up @@ -66,6 +66,7 @@ def to_shallow_primitive(self) -> dict:
return result

@classmethod
@limit_primitive_type(dict)
def from_primitive(cls: Type[CostModels], value: dict) -> CostModels:
raise DeserializeException(
"Deserialization of cost model is impossible, because some information is lost "
Expand Down Expand Up @@ -480,11 +481,8 @@ def to_shallow_primitive(self) -> CBORTag:
return CBORTag(102, [self.CONSTR_ID, primitives])

@classmethod
@limit_primitive_type(CBORTag)
def from_primitive(cls: Type[PlutusData], value: CBORTag) -> PlutusData:
if not isinstance(value, CBORTag):
raise DeserializeException(
f"Unexpected type: {CBORTag}. Got {type(value)} instead."
)
if value.tag == 102:
tag = value.value[0]
if tag != cls.CONSTR_ID:
Expand Down Expand Up @@ -643,6 +641,7 @@ def _dfs(obj):
return _dfs(self.data)

@classmethod
@limit_primitive_type(CBORTag)
def from_primitive(cls: Type[RawPlutusData], value: CBORTag) -> RawPlutusData:
return cls(value)

Expand Down Expand Up @@ -675,6 +674,7 @@ def to_primitive(self) -> int:
return self.value

@classmethod
@limit_primitive_type(int)
def from_primitive(cls: Type[RedeemerTag], value: int) -> RedeemerTag:
return cls(value)

Expand Down Expand Up @@ -704,7 +704,8 @@ class Redeemer(ArrayCBORSerializable):
ex_units: ExecutionUnits = None

@classmethod
def from_primitive(cls: Type[Redeemer], values: List[Primitive]) -> Redeemer:
@limit_primitive_type(list)
def from_primitive(cls: Type[Redeemer], values: list) -> Redeemer:
if isinstance(values[2], CBORTag) and cls is Redeemer:
values[2] = RawPlutusData.from_primitive(values[2])
redeemer = super(Redeemer, cls).from_primitive(
Expand Down
53 changes: 35 additions & 18 deletions pycardano/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dataclasses import Field, dataclass, fields
from datetime import datetime
from decimal import Decimal
from functools import wraps
from inspect import isclass
from typing import Any, Callable, List, Type, TypeVar, Union, get_type_hints

Expand Down Expand Up @@ -105,6 +106,31 @@ class RawCBOR:
`Cbor2 encoder <https://cbor2.readthedocs.io/en/latest/modules/encoder.html>`_ directly.
"""


def limit_primitive_type(*allowed_types):
"""
A helper function to validate primitive type given to from_primitive class methods

Not exposed to public by intention.
"""

def decorator(func):
@wraps(func)
def wrapper(cls, value: Primitive):
if not isinstance(value, allowed_types):
allowed_types_str = [
allowed_type.__name__ for allowed_type in allowed_types
]
raise DeserializeException(
f"{allowed_types_str} typed value is required for deserialization. Got {type(value)}: {value}"
)
return func(cls, value)

return wrapper

return decorator


CBORBase = TypeVar("CBORBase", bound="CBORSerializable")


Expand Down Expand Up @@ -245,7 +271,7 @@ def to_validated_primitive(self) -> Primitive:
return self.to_primitive()

@classmethod
def from_primitive(cls: Type[CBORBase], value: Primitive) -> CBORBase:
def from_primitive(cls: Type[CBORBase], value: Any) -> CBORBase:
"""Turn a CBOR primitive to its original class type.

Args:
Expand Down Expand Up @@ -407,7 +433,7 @@ def _restore_dataclass_field(
elif t in PRIMITIVE_TYPES and isinstance(v, t):
return v
raise DeserializeException(
f"Cannot deserialize object: \n{str(v)}\n in any valid type from {t_args}."
f"Cannot deserialize object: \n{v}\n in any valid type from {t_args}."
)
return v

Expand Down Expand Up @@ -494,7 +520,8 @@ def to_shallow_primitive(self) -> List[Primitive]:
return primitives

@classmethod
def from_primitive(cls: Type[ArrayBase], values: Primitive) -> ArrayBase:
@limit_primitive_type(list)
def from_primitive(cls: Type[ArrayBase], values: list) -> ArrayBase:
"""Restore a primitive value to its original class type.

Args:
Expand All @@ -508,10 +535,6 @@ def from_primitive(cls: Type[ArrayBase], values: Primitive) -> ArrayBase:
DeserializeException: When the object could not be restored from primitives.
"""
all_fields = [f for f in fields(cls) if f.init]
if type(values) != list:
raise DeserializeException(
f"Expect input value to be a list, got a {type(values)} instead."
)

restored_vals = []
type_hints = get_type_hints(cls)
Expand Down Expand Up @@ -606,7 +629,8 @@ def to_shallow_primitive(self) -> Primitive:
return primitives

@classmethod
def from_primitive(cls: Type[MapBase], values: Primitive) -> MapBase:
@limit_primitive_type(dict)
def from_primitive(cls: Type[MapBase], values: dict) -> MapBase:
"""Restore a primitive value to its original class type.

Args:
Expand All @@ -620,10 +644,6 @@ def from_primitive(cls: Type[MapBase], values: Primitive) -> MapBase:
:class:`pycardano.exception.DeserializeException`: When the object could not be restored from primitives.
"""
all_fields = {f.metadata.get("key", f.name): f for f in fields(cls) if f.init}
if type(values) != dict:
raise DeserializeException(
f"Expect input value to be a dict, got a {type(values)} instead."
)

kwargs = {}
type_hints = get_type_hints(cls)
Expand Down Expand Up @@ -725,7 +745,8 @@ def _get_sortable_val(key):
return dict(sorted(self.data.items(), key=lambda x: _get_sortable_val(x[0])))

@classmethod
def from_primitive(cls: Type[DictBase], value: Primitive) -> DictBase:
@limit_primitive_type(dict)
def from_primitive(cls: Type[DictBase], value: dict) -> DictBase:
"""Restore a primitive value to its original class type.

Args:
Expand All @@ -739,11 +760,7 @@ def from_primitive(cls: Type[DictBase], value: Primitive) -> DictBase:
DeserializeException: When the object could not be restored from primitives.
"""
if not value:
raise DeserializeException(f"Cannot accept empty value {str(value)}.")
if not isinstance(value, dict):
raise DeserializeException(
f"A dictionary value is required for deserialization: {str(value)}"
)
raise DeserializeException(f"Cannot accept empty value {value}.")

restored = cls()
for k, v in value.items():
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ profile = "black"

[tool.mypy]
ignore_missing_imports = true
disable_error_code = ["str-bytes-safe"]
python_version = 3.7
exclude = [
'^pycardano/cip/cip8.py$',
'^pycardano/crypto/bech32.py$',
'^pycardano/address.py$',
'^pycardano/certificate.py$',
'^pycardano/coinselection.py$',
'^pycardano/exception.py$',
Expand Down
29 changes: 28 additions & 1 deletion test/pycardano/test_address.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from pycardano.address import Address
from unittest import TestCase

from pycardano.address import Address, PointerAddress
from pycardano.exception import DeserializeException
from pycardano.key import PaymentVerificationKey
from pycardano.network import Network

Expand All @@ -15,3 +18,27 @@ def test_payment_addr():
Address(vk.hash(), network=Network.TESTNET).encode()
== "addr_test1vr2p8st5t5cxqglyjky7vk98k7jtfhdpvhl4e97cezuhn0cqcexl7"
)


class PointerAddressTest(TestCase):
def test_from_primitive_invalid_value(self):
with self.assertRaises(DeserializeException):
PointerAddress.from_primitive(1)

with self.assertRaises(DeserializeException):
PointerAddress.from_primitive([])

with self.assertRaises(DeserializeException):
PointerAddress.from_primitive({})


class AddressTest(TestCase):
def test_from_primitive_invalid_value(self):
with self.assertRaises(DeserializeException):
Address.from_primitive(1)

with self.assertRaises(DeserializeException):
Address.from_primitive([])

with self.assertRaises(DeserializeException):
Address.from_primitive({})
Loading