diff --git a/brownie/test/strategies.py b/brownie/test/strategies.py index 136fa0890..bb089e566 100644 --- a/brownie/test/strategies.py +++ b/brownie/test/strategies.py @@ -1,6 +1,6 @@ #!/usr/bin/python3 -from typing import Any, Callable, Iterable, Optional, Tuple, Union +from typing import Any, Callable, Iterable, Literal, Optional, Tuple, Union, overload from eth_abi.grammar import BasicType, TupleType, parse from hypothesis import strategies as st @@ -16,6 +16,76 @@ ArrayLengthType = Union[int, list, None] NumberType = Union[float, int, None] +EvmIntType = Literal[ + "int8", + "int16", + "int24", + "int32", + "int40", + "int48", + "int56", + "int64", + "int72", + "int80", + "int88", + "int96", + "int104", + "int112", + "int120", + "int128", + "int136", + "int144", + "int152", + "int160", + "int168", + "int176", + "int184", + "int192", + "int200", + "int208", + "int216", + "int224", + "int232", + "int240", + "int248", + "int256", +] + +EvmUintType = Literal[ + "uint8", + "uint16", + "uint24", + "uint32", + "uint40", + "uint48", + "uint56", + "uint64", + "uint72", + "uint80", + "uint88", + "uint96", + "uint104", + "uint112", + "uint120", + "uint128", + "uint136", + "uint144", + "uint152", + "uint160", + "uint168", + "uint176", + "uint184", + "uint192", + "uint200", + "uint208", + "uint216", + "uint224", + "uint232", + "uint240", + "uint248", + "uint256", +] + class _DeferredStrategyRepr(DeferredStrategy): def __init__(self, fn: Callable, repr_target: str) -> None: @@ -76,9 +146,9 @@ def _decimal_strategy( @_exclude_filter -def _address_strategy(length: Optional[int] = None) -> SearchStrategy: +def _address_strategy(length: Optional[int] = None, include: list = []) -> SearchStrategy: return _DeferredStrategyRepr( - lambda: st.sampled_from(list(network.accounts)[:length]), "accounts" + lambda: st.sampled_from(list(network.accounts)[:length] + include), "accounts" ) @@ -153,6 +223,22 @@ def _contract_deferred(name): return _DeferredStrategyRepr(lambda: _contract_deferred(contract_name), contract_name) +@overload +def strategy( + type_str: Literal["address"], + length: Optional[int] = None, + include: list = [], +) -> SearchStrategy: ... + + +@overload +def strategy( + type_str: Union[EvmIntType, EvmUintType], + min_value: Optional[int] = None, + max_value: Optional[int] = None, +) -> SearchStrategy: ... + + def strategy(type_str: str, **kwargs: Any) -> SearchStrategy: type_str = TYPE_STR_TRANSLATIONS.get(type_str, type_str) if type_str == "fixed168x10": diff --git a/setup.cfg b/setup.cfg index 6fed16208..7d815cafb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,7 +8,7 @@ current_version = 1.20.5 [flake8] exclude = tests/data/* max-line-length = 100 -ignore = E203,W503 +ignore = E203,E704,W503 [tool:isort] force_grid_wrap = 0