diff --git a/tests/functional/codegen/types/numbers/test_constants.py b/tests/functional/codegen/types/numbers/test_constants.py index f29a20153b..308fdfa32f 100644 --- a/tests/functional/codegen/types/numbers/test_constants.py +++ b/tests/functional/codegen/types/numbers/test_constants.py @@ -5,7 +5,7 @@ from tests.utils import ZERO_ADDRESS, decimal_to_int from vyper.compiler import compile_code from vyper.exceptions import TypeMismatch -from vyper.utils import MemoryPositions +from vyper.utils import MemoryPositions, hex_to_int def search_for_sublist(ir, sublist): @@ -197,6 +197,23 @@ def test() -> Bytes[100]: assert c.test() == test_str +def test_constant_hex_int(get_contract): + test_value = "0xfa" + code = f""" +X: constant(uint8) = {test_value} + +@external +def test() -> uint8: + y: uint8 = X + + return y + """ + + c = get_contract(code) + + assert c.test() == hex_to_int(test_value) + + def test_constant_folds(experimental_codegen): some_prime = 10013677 code = f""" diff --git a/tests/functional/codegen/types/numbers/test_unsigned_ints.py b/tests/functional/codegen/types/numbers/test_unsigned_ints.py index 2bd3184ec0..59ed26e7e9 100644 --- a/tests/functional/codegen/types/numbers/test_unsigned_ints.py +++ b/tests/functional/codegen/types/numbers/test_unsigned_ints.py @@ -210,7 +210,8 @@ def foo(x: {typ}, y: {typ}) -> bool: @pytest.mark.parametrize("typ", types) -def test_uint_literal(get_contract, typ): +@pytest.mark.parametrize("is_hex_int", [True, False]) +def test_uint_literal(get_contract, typ, is_hex_int): lo, hi = typ.ast_bounds good_cases = [0, 1, 2, 3, hi // 2 - 1, hi // 2, hi // 2 + 1, hi - 1, hi] @@ -222,11 +223,21 @@ def test() -> {typ}: return o """ + def _to_hex_int(v): + n_nibbles = typ.bits // 4 + return "0x" + hex(v)[2:].rjust(n_nibbles, "0") + for val in good_cases: - c = get_contract(code_template.format(typ=typ, val=val)) + input_val = val + if is_hex_int: + n_nibbles = typ.bits // 4 + input_val = "0x" + hex(val)[2:].rjust(n_nibbles, "0") + c = get_contract(code_template.format(typ=typ, val=input_val)) assert c.test() == val for val in bad_cases: + if is_hex_int: + return exc = ( TypeMismatch if SizeLimits.MIN_INT256 <= val <= SizeLimits.MAX_UINT256 diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 974685f403..5e9132b643 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -31,6 +31,7 @@ SizeLimits, annotate_source_code, evm_div, + hex_to_int, quantize, sha256sum, ) @@ -883,6 +884,13 @@ def bytes_value(self): """ return bytes.fromhex(self.value.removeprefix("0x")) + @property + def int_value(self): + """ + This value as integer + """ + return hex_to_int(self.value) + class Str(Constant): __slots__ = () diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 783764271d..d4c7c60b95 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -120,6 +120,8 @@ class Decimal(Num): ... class Hex(Num): @property def n_bytes(self): ... + @property + def int_value(self): ... class Str(Constant): ... class Bytes(Constant): ... diff --git a/vyper/builtins/_signatures.py b/vyper/builtins/_signatures.py index d012e4a1cf..b61b5e9dfb 100644 --- a/vyper/builtins/_signatures.py +++ b/vyper/builtins/_signatures.py @@ -96,7 +96,8 @@ def modifiability(self): # helper function to deal with TYPE_Ts def _validate_single(self, arg: vy_ast.VyperNode, expected_type: VyperType) -> None: - if TYPE_T.any().compare_type(expected_type): + constant_node = arg if isinstance(arg, vy_ast.Constant) else None + if TYPE_T.any().compare_type(expected_type, constant_node): # try to parse the type - call type_from_annotation # for its side effects (will throw if is not a type) type_from_annotation(arg) diff --git a/vyper/codegen/expr.py b/vyper/codegen/expr.py index 3a09bbe6c0..126494a781 100644 --- a/vyper/codegen/expr.py +++ b/vyper/codegen/expr.py @@ -16,6 +16,7 @@ is_array_like, is_bytes_m_type, is_flag_type, + is_integer_type, is_numeric_type, is_tuple_like, make_setter, @@ -133,6 +134,13 @@ def parse_Hex(self): return IRnode.from_list(val, typ=t) + elif is_integer_type(t): + n_bits = n_bytes * 8 + assert t.bits <= n_bits + + val = self.expr.int_value + return IRnode.from_list(val, typ=t) + # String literals def parse_Str(self): bytez, bytez_length = string_to_bytes(self.expr.value) diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index a31ce7acc1..0e347755a8 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -595,7 +595,8 @@ def validate_expected_type(node, expected_type): return else: for given, expected in itertools.product(given_types, expected_type): - if expected.compare_type(given): + constant_node = node if isinstance(node, vy_ast.Constant) else None + if expected.compare_type(given, constant_node): return # validation failed, prepare a meaningful error message diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index aca37b33a3..b3320a9d2c 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -25,7 +25,7 @@ def __repr__(self): def __init__(self, type_): self.type_ = type_ - def compare_type(self, other): + def compare_type(self, other, is_constant): if isinstance(other, self.type_): return True # compare two GenericTypeAcceptors -- they are the same if the base @@ -290,7 +290,9 @@ def validate_literal(self, node: vy_ast.Constant) -> None: def validate_index_type(self, node: vy_ast.Subscript) -> None: raise StructureException(f"Not an indexable type: '{self}'", node) - def compare_type(self, other: "VyperType") -> bool: + def compare_type( + self, other: "VyperType", constant_node: Optional[vy_ast.Constant] = None + ) -> bool: """ Compare this type object against another type object. diff --git a/vyper/semantics/types/bytestrings.py b/vyper/semantics/types/bytestrings.py index 02e3bb213f..8586e05391 100644 --- a/vyper/semantics/types/bytestrings.py +++ b/vyper/semantics/types/bytestrings.py @@ -103,7 +103,7 @@ def set_min_length(self, min_length): raise CompilerPanic("Cannot reduce the min_length of ArrayValueType") self._min_length = min_length - def compare_type(self, other): + def compare_type(self, other, is_constant): if not super().compare_type(other): return False diff --git a/vyper/semantics/types/primitives.py b/vyper/semantics/types/primitives.py index 5c0362e662..6eab32e3f6 100644 --- a/vyper/semantics/types/primitives.py +++ b/vyper/semantics/types/primitives.py @@ -2,7 +2,7 @@ from decimal import Decimal from functools import cached_property -from typing import Any, Tuple, Union +from typing import Any, Optional, Tuple, Union from vyper import ast as vy_ast from vyper.abi_types import ABI_Address, ABI_Bool, ABI_BytesM, ABI_GIntM, ABIType @@ -100,7 +100,9 @@ def validate_literal(self, node: vy_ast.Constant) -> None: if nibbles not in (nibbles.lower(), nibbles.upper()): raise InvalidLiteral(f"Cannot mix uppercase and lowercase for {self} literal", node) - def compare_type(self, other: VyperType) -> bool: + def compare_type( + self, other: VyperType, constant_node: Optional[vy_ast.Constant] = None + ) -> bool: if not super().compare_type(other): return False assert isinstance(other, BytesM_T) @@ -291,7 +293,18 @@ def all(cls) -> Tuple["IntegerT", ...]: def abi_type(self) -> ABIType: return ABI_GIntM(self.bits, self.is_signed) - def compare_type(self, other: VyperType) -> bool: + def compare_type( + self, other: VyperType, constant_node: Optional[vy_ast.Constant] = None + ) -> bool: + # handle hex integers + if ( + not self.is_signed + and isinstance(other, BytesM_T) + and isinstance(constant_node, vy_ast.Hex) + ): + lo, hi = self.ast_bounds + return lo <= constant_node.int_value <= hi + # this function is performance sensitive # originally: # if not super().compare_type(other): @@ -414,6 +427,6 @@ def validate_literal(self, node: vy_ast.Constant) -> None: class SelfT(AddressT): _id = "self" - def compare_type(self, other): + def compare_type(self, other, is_constant): # compares true to AddressT return isinstance(other, type(self)) or isinstance(self, type(other)) diff --git a/vyper/semantics/types/subscriptable.py b/vyper/semantics/types/subscriptable.py index 4068d815d2..546c593749 100644 --- a/vyper/semantics/types/subscriptable.py +++ b/vyper/semantics/types/subscriptable.py @@ -58,7 +58,7 @@ def __repr__(self): return f"HashMap[{self.key_type}, {self.value_type}]" # TODO not sure this is used? - def compare_type(self, other): + def compare_type(self, other, is_constant): return ( super().compare_type(other) and self.key_type == other.key_type @@ -196,7 +196,7 @@ def subtype(self): def get_subscripted_type(self, node): return self.value_type - def compare_type(self, other): + def compare_type(self, other, is_constant): if not isinstance(self, type(other)): return False if self.length != other.length: @@ -273,7 +273,7 @@ def size_in_bytes(self): # one length word + size of the array items return 32 + self.value_type.size_in_bytes * self.length - def compare_type(self, other): + def compare_type(self, other, is_constant): # TODO allow static array to be assigned to dyn array? # if not isinstance(other, (DArrayT, SArrayT)): if not isinstance(self, type(other)): @@ -384,7 +384,7 @@ def get_subscripted_type(self, node): node = node.reduced() return self.member_types[node.value] - def compare_type(self, other): + def compare_type(self, other, is_constant): if not isinstance(self, type(other)): return False if self.length != other.length: diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index 73fa4878c7..33a9291d96 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -34,7 +34,7 @@ def __init__(self, members=None): def __eq__(self, other): return self is other - def compare_type(self, other): + def compare_type(self, other, is_constant): # object exact comparison is a bit tricky here since we have # to be careful to construct any given user type exactly # only one time. however, the alternative requires reasoning