diff --git a/setup.py b/setup.py index ad54f7a8..47d2b789 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ install_requires=[ "asyncio==3.4.3", "pyyaml==6.0", - "redis==4.4.4", + "redis==5.0.3", "sgx.py==0.9dev2", "skale-contracts==1.0.1a5", "typing-extensions==4.9.0", diff --git a/skale/contracts/allocator/__init__.py b/skale/contracts/allocator/__init__.py index 5d14ee33..6365e525 100644 --- a/skale/contracts/allocator/__init__.py +++ b/skale/contracts/allocator/__init__.py @@ -2,3 +2,5 @@ from skale.contracts.allocator.escrow import Escrow from skale.contracts.allocator.allocator import Allocator + +__all__ = ['Allocator', 'Escrow'] diff --git a/skale/contracts/allocator/allocator.py b/skale/contracts/allocator/allocator.py index 37f64b1d..22496a2f 100644 --- a/skale/contracts/allocator/allocator.py +++ b/skale/contracts/allocator/allocator.py @@ -18,11 +18,24 @@ # along with SKALE.py. If not, see . """ SKALE Allocator Core Escrow methods """ -from enum import IntEnum - -from skale.contracts.base_contract import BaseContract, transaction_method -from skale.transactions.exceptions import ContractLogicError -from skale.transactions.result import TxRes +from typing import Any, Dict, List + +from eth_typing import ChecksumAddress +from web3 import Web3 +from web3.contract.contract import ContractFunction +from web3.exceptions import ContractLogicError +from web3.types import Wei + +from skale.contracts.allocator_contract import AllocatorContract +from skale.contracts.base_contract import transaction_method +from skale.types.allocation import ( + BeneficiaryStatus, + BeneficiaryPlan, + Plan, + PlanId, + PlanWithId, + TimeUnit +) from skale.utils.helper import format_fields @@ -47,36 +60,25 @@ MAX_NUM_OF_BENEFICIARIES = 9999 -class TimeUnit(IntEnum): - DAY = 0 - MONTH = 1 - YEAR = 2 - - -class BeneficiaryStatus(IntEnum): - UNKNOWN = 0 - CONFIRMED = 1 - ACTIVE = 2 - TERMINATED = 3 - - -class Allocator(BaseContract): - def is_beneficiary_registered(self, beneficiary_address: str) -> bool: +class Allocator(AllocatorContract): + def is_beneficiary_registered(self, beneficiary_address: ChecksumAddress) -> bool: """Confirms whether the beneficiary is registered in a Plan. :returns: Boolean value :rtype: bool """ - return self.contract.functions.isBeneficiaryRegistered(beneficiary_address).call() + return bool(self.contract.functions.isBeneficiaryRegistered(beneficiary_address).call()) - def is_delegation_allowed(self, beneficiary_address: str) -> bool: - return self.contract.functions.isDelegationAllowed(beneficiary_address).call() + def is_delegation_allowed(self, beneficiary_address: ChecksumAddress) -> bool: + return bool(self.contract.functions.isDelegationAllowed(beneficiary_address).call()) - def is_vesting_active(self, beneficiary_address: str) -> bool: - return self.contract.functions.isVestingActive(beneficiary_address).call() + def is_vesting_active(self, beneficiary_address: ChecksumAddress) -> bool: + return bool(self.contract.functions.isVestingActive(beneficiary_address).call()) - def get_escrow_address(self, beneficiary_address: str) -> str: - return self.contract.functions.getEscrowAddress(beneficiary_address).call() + def get_escrow_address(self, beneficiary_address: ChecksumAddress) -> ChecksumAddress: + return Web3.to_checksum_address( + self.contract.functions.getEscrowAddress(beneficiary_address).call() + ) @transaction_method def add_plan( @@ -87,7 +89,7 @@ def add_plan( vesting_interval: int, can_delegate: bool, is_terminatable: bool - ) -> TxRes: + ) -> ContractFunction: return self.contract.functions.addPlan( vestingCliff=vesting_cliff, totalVestingDuration=total_vesting_duration, @@ -100,12 +102,12 @@ def add_plan( @transaction_method def connect_beneficiary_to_plan( self, - beneficiary_address: str, + beneficiary_address: ChecksumAddress, plan_id: int, start_month: int, full_amount: int, lockup_amount: int, - ) -> TxRes: + ) -> ContractFunction: return self.contract.functions.connectBeneficiaryToPlan( beneficiary=beneficiary_address, planId=plan_id, @@ -115,61 +117,102 @@ def connect_beneficiary_to_plan( ) @transaction_method - def start_vesting(self, beneficiary_address: str) -> TxRes: + def start_vesting(self, beneficiary_address: ChecksumAddress) -> ContractFunction: return self.contract.functions.startVesting(beneficiary_address) @transaction_method - def stop_vesting(self, beneficiary_address: str) -> TxRes: + def stop_vesting(self, beneficiary_address: ChecksumAddress) -> ContractFunction: return self.contract.functions.stopVesting(beneficiary_address) @transaction_method - def grant_role(self, role: bytes, address: str) -> TxRes: + def grant_role(self, role: bytes, address: ChecksumAddress) -> ContractFunction: return self.contract.functions.grantRole(role, address) def vesting_manager_role(self) -> bytes: - return self.contract.functions.VESTING_MANAGER_ROLE().call() + return bytes(self.contract.functions.VESTING_MANAGER_ROLE().call()) - def has_role(self, role: bytes, address: str) -> bool: - return self.contract.functions.hasRole(role, address).call() + def has_role(self, role: bytes, address: ChecksumAddress) -> bool: + return bool(self.contract.functions.hasRole(role, address).call()) - def __get_beneficiary_plan_params_raw(self, beneficiary_address: str): - return self.contract.functions.getBeneficiaryPlanParams(beneficiary_address).call() + def __get_beneficiary_plan_params_raw(self, beneficiary_address: ChecksumAddress) -> List[Any]: + return list(self.contract.functions.getBeneficiaryPlanParams(beneficiary_address).call()) @format_fields(BENEFICIARY_FIELDS) - def get_beneficiary_plan_params_dict(self, beneficiary_address: str) -> dict: + def get_beneficiary_plan_params_dict(self, beneficiary_address: ChecksumAddress) -> List[Any]: return self.__get_beneficiary_plan_params_raw(beneficiary_address) - def get_beneficiary_plan_params(self, beneficiary_address: str) -> dict: + def get_beneficiary_plan_params(self, beneficiary_address: ChecksumAddress) -> BeneficiaryPlan: plan_params = self.get_beneficiary_plan_params_dict(beneficiary_address) - plan_params['statusName'] = BeneficiaryStatus(plan_params['status']).name - return plan_params - - def __get_plan_raw(self, plan_id: int): - return self.contract.functions.getPlan(plan_id).call() + if plan_params is None: + raise ValueError('Plan for ', beneficiary_address, ' is missing') + if isinstance(plan_params, list): + return self._to_beneficiary_plan({ + **plan_params[0], + 'statusName': BeneficiaryStatus(plan_params[0]['status']).name + }) + if isinstance(plan_params, dict): + return self._to_beneficiary_plan({ + **plan_params, + 'statusName': BeneficiaryStatus(plan_params.get('status', 0)).name + }) + raise TypeError(f'Internal error on getting plan params for ${beneficiary_address}') + + def __get_plan_raw(self, plan_id: PlanId) -> List[Any]: + return list(self.contract.functions.getPlan(plan_id).call()) @format_fields(PLAN_FIELDS) - def get_plan(self, plan_id: int) -> dict: + def get_untyped_plan(self, plan_id: PlanId) -> List[Any]: return self.__get_plan_raw(plan_id) - def get_all_plans(self) -> dict: + def get_plan(self, plan_id: PlanId) -> Plan: + untyped_plan = self.get_untyped_plan(plan_id) + if untyped_plan is None: + raise ValueError('Plan ', plan_id, ' is missing') + if isinstance(untyped_plan, list): + return self._to_plan(untyped_plan[0]) + if isinstance(untyped_plan, dict): + return self._to_plan(untyped_plan) + raise TypeError(plan_id) + + def get_all_plans(self) -> List[PlanWithId]: plans = [] for i in range(1, MAX_NUM_OF_PLANS): try: - plan = self.get_plan(i) - plan['planId'] = i + plan_id = PlanId(i) + plan = PlanWithId({**self.get_plan(plan_id), 'planId': plan_id}) plans.append(plan) except (ContractLogicError, ValueError): break return plans - def calculate_vested_amount(self, address: str) -> int: - return self.contract.functions.calculateVestedAmount(address).call() - - def get_finish_vesting_time(self, address: str) -> int: - return self.contract.functions.getFinishVestingTime(address).call() - - def get_lockup_period_end_timestamp(self, address: str) -> int: - return self.contract.functions.getLockupPeriodEndTimestamp(address).call() - - def get_time_of_next_vest(self, address: str) -> int: - return self.contract.functions.getTimeOfNextVest(address).call() + def calculate_vested_amount(self, address: ChecksumAddress) -> Wei: + return Wei(self.contract.functions.calculateVestedAmount(address).call()) + + def get_finish_vesting_time(self, address: ChecksumAddress) -> int: + return int(self.contract.functions.getFinishVestingTime(address).call()) + + def get_lockup_period_end_timestamp(self, address: ChecksumAddress) -> int: + return int(self.contract.functions.getLockupPeriodEndTimestamp(address).call()) + + def get_time_of_next_vest(self, address: ChecksumAddress) -> int: + return int(self.contract.functions.getTimeOfNextVest(address).call()) + + def _to_plan(self, untyped_plan: Dict[str, Any]) -> Plan: + return Plan({ + 'totalVestingDuration': int(untyped_plan['totalVestingDuration']), + 'vestingCliff': int(untyped_plan['vestingCliff']), + 'vestingIntervalTimeUnit': TimeUnit(untyped_plan['vestingIntervalTimeUnit']), + 'vestingInterval': int(untyped_plan['vestingInterval']), + 'isDelegationAllowed': bool(untyped_plan['isDelegationAllowed']), + 'isTerminatable': bool(untyped_plan['isTerminatable']) + }) + + def _to_beneficiary_plan(self, untyped_beneficiary_plan: Dict[str, Any]) -> BeneficiaryPlan: + return BeneficiaryPlan({ + 'status': BeneficiaryStatus(untyped_beneficiary_plan['status']), + 'statusName': str(untyped_beneficiary_plan['statusName']), + 'planId': PlanId(untyped_beneficiary_plan['planId']), + 'startMonth': int(untyped_beneficiary_plan['startMonth']), + 'fullAmount': Wei(untyped_beneficiary_plan['fullAmount']), + 'amountAfterLockup': Wei(untyped_beneficiary_plan['amountAfterLockup']) + }) diff --git a/skale/contracts/allocator/escrow.py b/skale/contracts/allocator/escrow.py index c65feda6..c221578a 100644 --- a/skale/contracts/allocator/escrow.py +++ b/skale/contracts/allocator/escrow.py @@ -18,32 +18,49 @@ # along with SKALE.py. If not, see . """ SKALE Allocator Core Escrow methods """ +from __future__ import annotations import functools +from typing import Any, Callable, TYPE_CHECKING -from skale.contracts.base_contract import BaseContract, transaction_method +from eth_typing import ChecksumAddress +from web3.contract.contract import ContractFunction +from web3.types import Wei + +from skale.contracts.allocator_contract import AllocatorContract +from skale.contracts.base_contract import transaction_method from skale.transactions.result import TxRes +from skale.types.delegation import DelegationId +from skale.types.validator import ValidatorId + +if TYPE_CHECKING: + from skale.contracts.allocator.allocator import Allocator -def beneficiary_escrow(transaction): +def beneficiary_escrow(transaction: Callable[..., TxRes]) -> Callable[..., TxRes]: @functools.wraps(transaction) - def wrapper(self, *args, beneficiary_address, **kwargs): + def wrapper( + self: AllocatorContract, + *args: Any, + beneficiary_address: ChecksumAddress, + **kwargs: Any + ) -> TxRes: self.contract = self.skale.instance.get_contract('Escrow', beneficiary_address) return transaction(self, *args, **kwargs) return wrapper -class Escrow(BaseContract): +class Escrow(AllocatorContract): @property @functools.lru_cache() - def allocator(self): + def allocator(self) -> Allocator: return self.skale.allocator - def init_contract(self, skale, address, abi) -> None: - self.contract = None + def init_contract(self, *args: Any) -> None: + self.contract = self.allocator.contract @beneficiary_escrow @transaction_method - def retrieve(self) -> TxRes: + def retrieve(self) -> ContractFunction: """Allows Holder to retrieve vested tokens from the Escrow contract :returns: Transaction results @@ -53,7 +70,7 @@ def retrieve(self) -> TxRes: @beneficiary_escrow @transaction_method - def retrieve_after_termination(self, address: str) -> TxRes: + def retrieve_after_termination(self, address: ChecksumAddress) -> ContractFunction: """Allows Core Owner to retrieve remaining transferrable escrow balance after Core holder termination. Slashed tokens are non-transferable @@ -64,7 +81,13 @@ def retrieve_after_termination(self, address: str) -> TxRes: @beneficiary_escrow @transaction_method - def delegate(self, validator_id: int, amount: int, delegation_period: int, info: str) -> TxRes: + def delegate( + self, + validator_id: ValidatorId, + amount: Wei, + delegation_period: int, + info: str + ) -> ContractFunction: """Allows Core holder to propose a delegation to a validator :param validator_id: ID of the validator to delegate tokens @@ -82,7 +105,7 @@ def delegate(self, validator_id: int, amount: int, delegation_period: int, info: @beneficiary_escrow @transaction_method - def request_undelegation(self, delegation_id: int) -> TxRes: + def request_undelegation(self, delegation_id: DelegationId) -> ContractFunction: """Allows Holder and Owner to request undelegation. Only Owner can request undelegation after Core holder is deactivated (upon holder termination) @@ -95,7 +118,7 @@ def request_undelegation(self, delegation_id: int) -> TxRes: @beneficiary_escrow @transaction_method - def withdraw_bounty(self, validator_id: int, to: str) -> TxRes: + def withdraw_bounty(self, validator_id: ValidatorId, to: ChecksumAddress) -> ContractFunction: """Allows Beneficiary and Vesting Owner to withdraw earned bounty. :param validator_id: ID of the validator @@ -109,7 +132,7 @@ def withdraw_bounty(self, validator_id: int, to: str) -> TxRes: @beneficiary_escrow @transaction_method - def cancel_pending_delegation(self, delegation_id: int) -> TxRes: + def cancel_pending_delegation(self, delegation_id: DelegationId) -> ContractFunction: """Cancel pending delegation request. :param delegation_id: ID of the delegation to cancel diff --git a/skale/dataclasses/delegation_status.py b/skale/contracts/allocator_contract.py similarity index 80% rename from skale/dataclasses/delegation_status.py rename to skale/contracts/allocator_contract.py index f25bd675..00eea17f 100644 --- a/skale/dataclasses/delegation_status.py +++ b/skale/contracts/allocator_contract.py @@ -16,15 +16,9 @@ # # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . +from skale.contracts.base_contract import BaseContract +from skale.skale_allocator import SkaleAllocator -from enum import Enum - -class DelegationStatus(Enum): - PROPOSED = 0 - ACCEPTED = 1 - CANCELED = 2 - REJECTED = 3 - DELEGATED = 4 - UNDELEGATION_REQUESTED = 5 - COMPLETED = 6 +class AllocatorContract(BaseContract[SkaleAllocator]): + pass diff --git a/skale/contracts/base_contract.py b/skale/contracts/base_contract.py index fef65a85..48323c44 100644 --- a/skale/contracts/base_contract.py +++ b/skale/contracts/base_contract.py @@ -17,16 +17,19 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . """ SKALE base contract class """ - +from __future__ import annotations import logging from functools import wraps -from typing import Dict, Optional +from typing import Any, Callable, TYPE_CHECKING, Generic, TypeVar +from eth_typing import ChecksumAddress from web3 import Web3 +from web3.contract.contract import ContractFunction +from web3.types import ABI, Nonce, Wei import skale.config as config -from skale.transactions.result import TxRes -from skale.transactions.tools import make_dry_run_call, transaction_from_method, TxStatus +from skale.transactions.result import TxRes, TxStatus +from skale.transactions.tools import make_dry_run_call, transaction_from_method from skale.utils.web3_utils import ( DEFAULT_BLOCKS_TO_WAIT, get_eth_nonce, @@ -34,35 +37,74 @@ wait_for_confirmation_blocks ) +from skale.skale_base import SkaleBase from skale.utils.helper import to_camel_case +if TYPE_CHECKING: + pass + logger = logging.getLogger(__name__) -def transaction_method(transaction): +SkaleType = TypeVar('SkaleType', bound=SkaleBase) + + +class BaseContract(Generic[SkaleType]): + def __init__( + self, + skale: SkaleType, + name: str, + address: ChecksumAddress | str | bytes, + abi: ABI + ): + self.skale = skale + self.name = name + self.address = Web3.to_checksum_address(address) + self.init_contract(skale, self.address, abi) + + def init_contract(self, skale: SkaleBase, address: ChecksumAddress, abi: ABI) -> None: + self.contract = skale.web3.eth.contract(address=address, abi=abi) + + def __getattr__(self, attr: str) -> Callable[..., Any]: + """Fallback for contract calls""" + logger.debug("Calling contract function: %s", attr) + + def wrapper(*args: Any, **kw: Any) -> Any: + logger.debug('called with %r and %r' % (args, kw)) + camel_case_fn_name = to_camel_case(attr) + if hasattr(self.contract.functions, camel_case_fn_name): + return getattr(self.contract.functions, + camel_case_fn_name)(*args, **kw).call() + if hasattr(self.contract.functions, attr): + return getattr(self.contract.functions, + attr)(*args, **kw).call() + raise AttributeError(attr) + return wrapper + + +def transaction_method(transaction: Callable[..., ContractFunction]) -> Callable[..., TxRes]: @wraps(transaction) def wrapper( - self, - *args, - wait_for=True, - blocks_to_wait=DEFAULT_BLOCKS_TO_WAIT, - timeout=MAX_WAITING_TIME, - gas_limit=None, - gas_price=None, - nonce=None, - max_fee_per_gas=None, - max_priority_fee_per_gas=None, - value=0, - dry_run_only=False, - skip_dry_run=False, - raise_for_status=True, - multiplier=None, - priority=None, - confirmation_blocks=0, - meta: Optional[Dict] = None, - **kwargs - ): + self: BaseContract[SkaleType], + *args: Any, + wait_for: bool = True, + blocks_to_wait: int = DEFAULT_BLOCKS_TO_WAIT, + timeout: int = MAX_WAITING_TIME, + gas_limit: int | None = None, + gas_price: int | None = None, + nonce: Nonce | None = None, + max_fee_per_gas: int | None = None, + max_priority_fee_per_gas: int | None = None, + value: Wei = Wei(0), + dry_run_only: bool = False, + skip_dry_run: bool = False, + raise_for_status: bool = True, + multiplier: float | None = None, + priority: int | None = None, + confirmation_blocks: int = 0, + **kwargs: Any + ) -> TxRes: method = transaction(self, *args, **kwargs) nonce = get_eth_nonce(self.skale.web3, self.skale.wallet.address) @@ -70,13 +112,15 @@ def wrapper( call_result, tx_hash, receipt = None, None, None should_dry_run = not skip_dry_run and not config.DISABLE_DRY_RUN + dry_run_success = False if should_dry_run: call_result = make_dry_run_call(self.skale, method, gas_limit, value) if call_result.status == TxStatus.SUCCESS: - gas_limit = gas_limit or call_result.data['gas'] + gas_limit = gas_limit or int(call_result.data['gas']) + dry_run_success = True should_send = not dry_run_only and \ - (not should_dry_run or call_result.status == TxStatus.SUCCESS) + (not should_dry_run or dry_run_success) if should_send: gas_limit = gas_limit or config.DEFAULT_GAS_LIMIT @@ -95,12 +139,10 @@ def wrapper( tx, multiplier=multiplier, priority=priority, - method=method_name, - meta=meta + method=method_name ) - should_wait = tx_hash is not None and wait_for - if should_wait: + if tx_hash is not None and wait_for: receipt = self.skale.wallet.wait(tx_hash) should_confirm = receipt is not None and confirmation_blocks > 0 @@ -114,30 +156,3 @@ def wrapper( return tx_res return wrapper - - -class BaseContract: - def __init__(self, skale, name, address, abi): - self.skale = skale - self.name = name - self.address = Web3.to_checksum_address(address) - self.init_contract(skale, address, abi) - - def init_contract(self, skale, address, abi) -> None: - self.contract = skale.web3.eth.contract(address=address, abi=abi) - - def __getattr__(self, attr): - """Fallback for contract calls""" - logger.debug("Calling contract function: %s", attr) - - def wrapper(*args, **kw): - logger.debug('called with %r and %r' % (args, kw)) - camel_case_fn_name = to_camel_case(attr) - if hasattr(self.contract.functions, camel_case_fn_name): - return getattr(self.contract.functions, - camel_case_fn_name)(*args, **kw).call() - if hasattr(self.contract.functions, attr): - return getattr(self.contract.functions, - attr)(*args, **kw).call() - raise AttributeError(attr) - return wrapper diff --git a/skale/contracts/ima/__init__.py b/skale/contracts/ima/__init__.py index dd484307..465bdd4d 100644 --- a/skale/contracts/ima/__init__.py +++ b/skale/contracts/ima/__init__.py @@ -1,6 +1,10 @@ # flake8: noqa -from skale.contracts.contract_manager import ContractManager +from skale.contracts.manager.contract_manager import ContractManager from skale.contracts.base_contract import BaseContract, transaction_method from skale.contracts.ima.linker import Linker + +__all__ = [ + 'Linker' +] diff --git a/skale/contracts/ima/linker.py b/skale/contracts/ima/linker.py index 8fb40c48..9895a9c2 100644 --- a/skale/contracts/ima/linker.py +++ b/skale/contracts/ima/linker.py @@ -17,16 +17,22 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . -from skale.contracts.base_contract import BaseContract, transaction_method +from typing import List +from eth_typing import ChecksumAddress +from web3.contract.contract import ContractFunction +from skale.contracts.base_contract import transaction_method +from skale.contracts.ima_contract import ImaContract +from skale.types.schain import SchainName -class Linker(BaseContract): + +class Linker(ImaContract): @transaction_method def connect_schain( self, - schain_name: str, - mainnet_contracts: list - ): + schain_name: SchainName, + mainnet_contracts: List[ChecksumAddress] + ) -> ContractFunction: return self.contract.functions.connectSchain( schain_name, mainnet_contracts diff --git a/skale/contracts/ima_contract.py b/skale/contracts/ima_contract.py new file mode 100644 index 00000000..ff45ab18 --- /dev/null +++ b/skale/contracts/ima_contract.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +# +# This file is part of SKALE.py +# +# Copyright (C) 2019-Present SKALE Labs +# +# SKALE.py is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# SKALE.py is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with SKALE.py. If not, see . +from skale.contracts.base_contract import BaseContract +from skale.skale_ima import SkaleIma + + +class ImaContract(BaseContract[SkaleIma]): + pass diff --git a/skale/contracts/manager/__init__.py b/skale/contracts/manager/__init__.py index 475967bf..93f126e2 100644 --- a/skale/contracts/manager/__init__.py +++ b/skale/contracts/manager/__init__.py @@ -1,6 +1,6 @@ # flake8: noqa -from skale.contracts.contract_manager import ContractManager +from skale.contracts.manager.contract_manager import ContractManager from skale.contracts.base_contract import BaseContract, transaction_method from skale.contracts.manager.manager import Manager @@ -31,3 +31,27 @@ from skale.contracts.manager.sync_manager import SyncManager from skale.contracts.manager.test.time_helpers_with_debug import TimeHelpersWithDebug + +__all__ = [ + 'BountyV2', + 'ConstantsHolder', + 'ContractManager', + 'DelegationController', + 'DelegationPeriodManager', + 'Distributor', + 'DKG', + 'KeyStorage', + 'Manager', + 'NodeRotation', + 'Nodes', + 'Punisher', + 'SChains', + 'SChainsInternal', + 'SlashingTable', + 'SyncManager', + 'TimeHelpersWithDebug', + 'Token', + 'TokenState', + 'ValidatorService', + 'Wallets' +] diff --git a/skale/contracts/manager/bounty_v2.py b/skale/contracts/manager/bounty_v2.py index 3fba4724..9df79de2 100644 --- a/skale/contracts/manager/bounty_v2.py +++ b/skale/contracts/manager/bounty_v2.py @@ -17,17 +17,20 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . -from skale.contracts.base_contract import BaseContract, transaction_method -from skale.transactions.result import TxRes +from eth_typing import ChecksumAddress +from web3.contract.contract import ContractFunction +from skale.contracts.base_contract import transaction_method +from skale.contracts.skale_manager_contract import SkaleManagerContract -class BountyV2(BaseContract): + +class BountyV2(SkaleManagerContract): @transaction_method - def grant_role(self, role: bytes, owner: str) -> TxRes: + def grant_role(self, role: bytes, owner: ChecksumAddress) -> ContractFunction: return self.contract.functions.grantRole(role, owner) - def has_role(self, role: bytes, address: str) -> bool: - return self.contract.functions.hasRole(role, address).call() + def has_role(self, role: bytes, address: ChecksumAddress) -> bool: + return bool(self.contract.functions.hasRole(role, address).call()) - def bounty_reduction_manager_role(self): - return self.contract.functions.BOUNTY_REDUCTION_MANAGER_ROLE().call() + def bounty_reduction_manager_role(self) -> bytes: + return bytes(self.contract.functions.BOUNTY_REDUCTION_MANAGER_ROLE().call()) diff --git a/skale/contracts/manager/constants_holder.py b/skale/contracts/manager/constants_holder.py index dda23842..d912bf6b 100644 --- a/skale/contracts/manager/constants_holder.py +++ b/skale/contracts/manager/constants_holder.py @@ -17,40 +17,43 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . -from skale.contracts.base_contract import BaseContract, transaction_method -from skale.transactions.result import TxRes +from eth_typing import ChecksumAddress +from web3.contract.contract import ContractFunction +from skale.contracts.base_contract import transaction_method +from skale.contracts.skale_manager_contract import SkaleManagerContract -class ConstantsHolder(BaseContract): + +class ConstantsHolder(SkaleManagerContract): @transaction_method - def set_periods(self, new_reward_period, new_delta_period): + def set_periods(self, new_reward_period: int, new_delta_period: int) -> ContractFunction: return self.contract.functions.setPeriods( new_reward_period, new_delta_period ) - def get_reward_period(self): - return self.contract.functions.rewardPeriod().call() + def get_reward_period(self) -> int: + return int(self.contract.functions.rewardPeriod().call()) - def get_delta_period(self): - return self.contract.functions.deltaPeriod().call() + def get_delta_period(self) -> int: + return int(self.contract.functions.deltaPeriod().call()) @transaction_method - def set_check_time(self, new_check_time): + def set_check_time(self, new_check_time: int) -> ContractFunction: return self.contract.functions.setCheckTime(new_check_time) - def get_check_time(self): - return self.contract.functions.checkTime().call() + def get_check_time(self) -> int: + return int(self.contract.functions.checkTime().call()) @transaction_method - def set_latency(self, new_allowable_latency): + def set_latency(self, new_allowable_latency: int) -> ContractFunction: return self.contract.functions.setLatency(new_allowable_latency) - def get_latency(self): - return self.contract.functions.allowableLatency().call() + def get_latency(self) -> int: + return int(self.contract.functions.allowableLatency().call()) - def get_first_delegation_month(self): - return self.contract.functions.firstDelegationsMonth().call() + def get_first_delegation_month(self) -> int: + return int(self.contract.functions.firstDelegationsMonth().call()) def msr(self) -> int: """Minimum staking requirement to create a node. @@ -58,41 +61,41 @@ def msr(self) -> int: :returns: MSR (in wei) :rtype: int """ - return self.contract.functions.msr().call() + return int(self.contract.functions.msr().call()) @transaction_method - def _set_msr(self, new_msr: int) -> None: + def _set_msr(self, new_msr: int) -> ContractFunction: """For internal usage only""" return self.contract.functions.setMSR(new_msr) def get_launch_timestamp(self) -> int: - return self.contract.functions.launchTimestamp().call() + return int(self.contract.functions.launchTimestamp().call()) @transaction_method - def set_launch_timestamp(self, launch_timestamp: int): + def set_launch_timestamp(self, launch_timestamp: int) -> ContractFunction: return self.contract.functions.setLaunchTimestamp(launch_timestamp) @transaction_method - def set_rotation_delay(self, rotation_delay: int) -> None: + def set_rotation_delay(self, rotation_delay: int) -> ContractFunction: """For internal usage only""" return self.contract.functions.setRotationDelay(rotation_delay) def get_rotation_delay(self) -> int: - return self.contract.functions.rotationDelay().call() + return int(self.contract.functions.rotationDelay().call()) def get_dkg_timeout(self) -> int: - return self.contract.functions.complaintTimeLimit().call() + return int(self.contract.functions.complaintTimeLimit().call()) @transaction_method - def set_complaint_timelimit(self, complaint_timelimit: int): + def set_complaint_timelimit(self, complaint_timelimit: int) -> ContractFunction: return self.contract.functions.setComplaintTimeLimit(complaint_timelimit) @transaction_method - def grant_role(self, role: bytes, address: str) -> TxRes: + def grant_role(self, role: bytes, address: ChecksumAddress) -> ContractFunction: return self.contract.functions.grantRole(role, address) def constants_holder_role(self) -> bytes: - return self.contract.functions.CONSTANTS_HOLDER_MANAGER_ROLE().call() + return bytes(self.contract.functions.CONSTANTS_HOLDER_MANAGER_ROLE().call()) - def has_role(self, role: bytes, address: str) -> bool: - return self.contract.functions.hasRole(role, address).call() + def has_role(self, role: bytes, address: ChecksumAddress) -> bool: + return bool(self.contract.functions.hasRole(role, address).call()) diff --git a/skale/contracts/contract_manager.py b/skale/contracts/manager/contract_manager.py similarity index 72% rename from skale/contracts/contract_manager.py rename to skale/contracts/manager/contract_manager.py index 5b33de65..d2fe91ce 100644 --- a/skale/contracts/contract_manager.py +++ b/skale/contracts/manager/contract_manager.py @@ -19,16 +19,18 @@ """ SKALE Contract manager class """ from Crypto.Hash import keccak +from eth_typing import ChecksumAddress +from web3 import Web3 -from skale.contracts.base_contract import BaseContract +from skale.contracts.skale_manager_contract import SkaleManagerContract from skale.utils.helper import add_0x_prefix -class ContractManager(BaseContract): - def get_contract_address(self, name): +class ContractManager(SkaleManagerContract): + def get_contract_address(self, name: str) -> ChecksumAddress: contract_hash = add_0x_prefix(self.get_contract_hash_by_name(name)) - return self.contract.functions.contracts(contract_hash).call() + return Web3.to_checksum_address(self.contract.functions.contracts(contract_hash).call()) - def get_contract_hash_by_name(self, name): + def get_contract_hash_by_name(self, name: str) -> str: keccak_hash = keccak.new(data=name.encode("utf8"), digest_bits=256) return keccak_hash.hexdigest() diff --git a/skale/contracts/manager/delegation/delegation_controller.py b/skale/contracts/manager/delegation/delegation_controller.py index adcb2e1b..b54853fe 100644 --- a/skale/contracts/manager/delegation/delegation_controller.py +++ b/skale/contracts/manager/delegation/delegation_controller.py @@ -17,10 +17,17 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . -from skale.contracts.base_contract import BaseContract, transaction_method +from typing import Any, Dict, List + +from eth_typing import ChecksumAddress +from web3.contract.contract import ContractFunction +from web3.types import Wei + +from skale.contracts.base_contract import transaction_method +from skale.contracts.skale_manager_contract import SkaleManagerContract +from skale.types.delegation import Delegation, DelegationId, DelegationStatus, FullDelegation +from skale.types.validator import ValidatorId from skale.utils.helper import format_fields -from skale.transactions.result import TxRes -from skale.dataclasses.delegation_status import DelegationStatus FIELDS = [ @@ -29,67 +36,82 @@ ] -class DelegationController(BaseContract): +class DelegationController(SkaleManagerContract): """Wrapper for DelegationController.sol functions""" @format_fields(FIELDS) - def get_delegation(self, delegation_id: int) -> dict: + def get_untyped_delegation(self, delegation_id: DelegationId) -> List[Any]: """Returns delegation structure. :returns: Info about delegation request - :rtype: dict + :rtype: list """ return self.__raw_get_delegation(delegation_id) - def get_delegation_full(self, delegation_id: int) -> dict: + def get_delegation(self, delegation_id: DelegationId) -> Delegation: + delegation = self.get_untyped_delegation(delegation_id) + if delegation is None: + raise ValueError("Can't get delegation with id ", delegation_id) + if isinstance(delegation, dict): + return self._to_delegation(delegation) + if isinstance(delegation, list): + return self._to_delegation(delegation[0]) + raise TypeError(delegation_id) + + def get_delegation_full(self, delegation_id: DelegationId) -> FullDelegation: """Returns delegation structure with ID and status fields. :returns: Info about delegation request :rtype: dict """ delegation = self.get_delegation(delegation_id) - delegation['id'] = delegation_id - delegation['status'] = self._get_delegation_status(delegation_id) - return delegation + return FullDelegation({ + 'id': delegation_id, + 'status': self._get_delegation_status(delegation_id), + **delegation + }) - def __raw_get_delegation(self, delegation_id: int) -> list: + def __raw_get_delegation(self, delegation_id: DelegationId) -> List[Any]: """Returns raw delegation fields. :returns: Info about delegation request :rtype: list """ - return self.contract.functions.getDelegation(delegation_id).call() + return list(self.contract.functions.getDelegation(delegation_id).call()) - def _get_delegation_ids_by_validator(self, validator_id: int) -> list: + def _get_delegation_ids_by_validator(self, validator_id: ValidatorId) -> List[DelegationId]: delegation_ids_len = self._get_delegation_ids_len_by_validator( validator_id) return [ - self.contract.functions.delegationsByValidator( - validator_id, _id).call() + DelegationId( + self.contract.functions.delegationsByValidator(validator_id, _id).call() + ) for _id in range(delegation_ids_len) ] - def _get_delegation_ids_by_holder(self, address: str) -> list: + def _get_delegation_ids_by_holder(self, address: ChecksumAddress) -> List[DelegationId]: delegation_ids_len = self._get_delegation_ids_len_by_holder(address) return [ - self.contract.functions.delegationsByHolder(address, _id).call() + DelegationId( + self.contract.functions.delegationsByHolder(address, _id).call() + ) for _id in range(delegation_ids_len) ] - def _get_delegation_ids_len_by_validator(self, validator_id: int) -> list: - return self.contract.functions.getDelegationsByValidatorLength(validator_id).call() + def _get_delegation_ids_len_by_validator(self, validator_id: ValidatorId) -> int: + return int(self.contract.functions.getDelegationsByValidatorLength(validator_id).call()) - def _get_delegation_ids_len_by_holder(self, address: str) -> list: - return self.contract.functions.getDelegationsByHolderLength(address).call() + def _get_delegation_ids_len_by_holder(self, address: ChecksumAddress) -> int: + return int(self.contract.functions.getDelegationsByHolderLength(address).call()) - def _get_delegation_state_index(self, delegation_id: int) -> str: - return self.contract.functions.getState(delegation_id).call() + def _get_delegation_state_index(self, delegation_id: DelegationId) -> int: + return int(self.contract.functions.getState(delegation_id).call()) - def _get_delegation_status(self, delegation_id: int) -> str: + def _get_delegation_status(self, delegation_id: DelegationId) -> DelegationStatus: index = self._get_delegation_state_index(delegation_id) - return DelegationStatus(index).name + return DelegationStatus(index) - def get_all_delegations(self, delegation_ids: list) -> list: + def get_all_delegations(self, delegation_ids: List[DelegationId]) -> List[FullDelegation]: """Returns list of formatted delegations with particular status. :param delegation_ids: List of delegation IDs @@ -102,7 +124,7 @@ def get_all_delegations(self, delegation_ids: list) -> list: for _id in delegation_ids ] - def get_all_delegations_by_holder(self, address: str) -> list: + def get_all_delegations_by_holder(self, address: ChecksumAddress) -> List[FullDelegation]: """Returns list of formatted delegations for token holder. :param address: Ethereum address @@ -113,7 +135,7 @@ def get_all_delegations_by_holder(self, address: str) -> list: delegation_ids = self._get_delegation_ids_by_holder(address) return self.get_all_delegations(delegation_ids) - def get_all_delegations_by_validator(self, validator_id: int) -> list: + def get_all_delegations_by_validator(self, validator_id: ValidatorId) -> List[FullDelegation]: """Returns list of formatted delegations for validator. :param validator_id: ID of the validator @@ -121,12 +143,17 @@ def get_all_delegations_by_validator(self, validator_id: int) -> list: :returns: List of formatted delegations :rtype: list """ - validator_id = int(validator_id) delegation_ids = self._get_delegation_ids_by_validator(validator_id) return self.get_all_delegations(delegation_ids) @transaction_method - def delegate(self, validator_id: int, amount: int, delegation_period: int, info: str) -> TxRes: + def delegate( + self, + validator_id: ValidatorId, + amount: Wei, + delegation_period: int, + info: str + ) -> ContractFunction: """Creates request to delegate amount of tokens to validator_id. :param validator_id: ID of the validator to delegate tokens @@ -143,7 +170,7 @@ def delegate(self, validator_id: int, amount: int, delegation_period: int, info: return self.contract.functions.delegate(validator_id, amount, delegation_period, info) @transaction_method - def accept_pending_delegation(self, delegation_id: int) -> TxRes: + def accept_pending_delegation(self, delegation_id: DelegationId) -> ContractFunction: """Accepts a pending delegation by delegation ID. :param delegation_id: Delegation ID to accept @@ -154,7 +181,7 @@ def accept_pending_delegation(self, delegation_id: int) -> TxRes: return self.contract.functions.acceptPendingDelegation(delegation_id) @transaction_method - def cancel_pending_delegation(self, delegation_id: int) -> TxRes: + def cancel_pending_delegation(self, delegation_id: DelegationId) -> ContractFunction: """Cancel pending delegation request. :param delegation_id: ID of the delegation to cancel @@ -165,7 +192,7 @@ def cancel_pending_delegation(self, delegation_id: int) -> TxRes: return self.contract.functions.cancelPendingDelegation(delegation_id) @transaction_method - def request_undelegation(self, delegation_id: int) -> TxRes: + def request_undelegation(self, delegation_id: DelegationId) -> ContractFunction: """ This method is for undelegating request in the end of delegation period (3/6/12 months) @@ -176,7 +203,7 @@ def request_undelegation(self, delegation_id: int) -> TxRes: """ return self.contract.functions.requestUndelegation(delegation_id) - def get_delegated_to_validator_now(self, validator_id: int) -> int: + def get_delegated_to_validator_now(self, validator_id: ValidatorId) -> Wei: """Amount of delegated tokens to the validator :param validator_id: ID of the validator @@ -184,9 +211,9 @@ def get_delegated_to_validator_now(self, validator_id: int) -> int: :returns: Amount of delegated tokens :rtype: int """ - return self.contract.functions.getAndUpdateDelegatedToValidatorNow(validator_id).call() + return Wei(self.contract.functions.getAndUpdateDelegatedToValidatorNow(validator_id).call()) - def get_delegated_to_validator(self, validator_id: int, month: int) -> int: + def get_delegated_to_validator(self, validator_id: ValidatorId, month: int) -> Wei: """Amount of delegated tokens to the validator unil month :param validator_id: ID of the validator @@ -197,9 +224,9 @@ def get_delegated_to_validator(self, validator_id: int, month: int) -> int: :rtype: int """ - return self.contract.functions.getDelegatedToValidator(validator_id, month).call() + return Wei(self.contract.functions.getDelegatedToValidator(validator_id, month).call()) - def get_delegated_amount(self, address: str) -> int: + def get_delegated_amount(self, address: ChecksumAddress) -> Wei: """Amount of delegated tokens by token holder :param address: Token holder address @@ -207,4 +234,16 @@ def get_delegated_amount(self, address: str) -> int: :returns: Amount of delegated tokens :rtype: int """ - return self.contract.functions.getAndUpdateDelegatedAmount(address).call() + return Wei(self.contract.functions.getAndUpdateDelegatedAmount(address).call()) + + def _to_delegation(self, delegation: Dict[str, Any]) -> Delegation: + return Delegation({ + 'address': ChecksumAddress(delegation['address']), + 'validator_id': ValidatorId(delegation['validator_id']), + 'amount': Wei(delegation['amount']), + 'delegation_period': int(delegation['delegation_period']), + 'created': int(delegation['created']), + 'started': int(delegation['started']), + 'finished': int(delegation['finished']), + 'info': str(delegation['info']) + }) diff --git a/skale/contracts/manager/delegation/delegation_period_manager.py b/skale/contracts/manager/delegation/delegation_period_manager.py index 3b753bfc..316b07e8 100644 --- a/skale/contracts/manager/delegation/delegation_period_manager.py +++ b/skale/contracts/manager/delegation/delegation_period_manager.py @@ -17,32 +17,35 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . -from skale.contracts.base_contract import BaseContract, transaction_method -from skale.transactions.result import TxRes +from eth_typing import ChecksumAddress +from web3.contract.contract import ContractFunction +from skale.contracts.base_contract import transaction_method +from skale.contracts.skale_manager_contract import SkaleManagerContract -class DelegationPeriodManager(BaseContract): + +class DelegationPeriodManager(SkaleManagerContract): """Wrapper for DelegationPeriodManager.sol functions""" @transaction_method def set_delegation_period(self, months_count: int, - stake_multiplier: int) -> None: + stake_multiplier: int) -> ContractFunction: return self.contract.functions.setDelegationPeriod( monthsCount=months_count, stakeMultiplier=stake_multiplier ) def is_delegation_period_allowed(self, months_count: int) -> bool: - return self.contract.functions.isDelegationPeriodAllowed( + return bool(self.contract.functions.isDelegationPeriodAllowed( monthsCount=months_count - ).call() + ).call()) @transaction_method - def grant_role(self, role: bytes, address: str) -> TxRes: + def grant_role(self, role: bytes, address: ChecksumAddress) -> ContractFunction: return self.contract.functions.grantRole(role, address) def delegation_period_setter_role(self) -> bytes: - return self.contract.functions.DELEGATION_PERIOD_SETTER_ROLE().call() + return bytes(self.contract.functions.DELEGATION_PERIOD_SETTER_ROLE().call()) - def has_role(self, role: bytes, address: str) -> bool: - return self.contract.functions.hasRole(role, address).call() + def has_role(self, role: bytes, address: ChecksumAddress) -> bool: + return bool(self.contract.functions.hasRole(role, address).call()) diff --git a/skale/contracts/manager/delegation/distributor.py b/skale/contracts/manager/delegation/distributor.py index d3312a5d..4d565b79 100644 --- a/skale/contracts/manager/delegation/distributor.py +++ b/skale/contracts/manager/delegation/distributor.py @@ -18,27 +18,42 @@ # along with SKALE.py. If not, see . from functools import wraps +from typing import Any, Callable, Tuple, TypedDict -from skale.contracts.base_contract import BaseContract, transaction_method -from skale.transactions.result import TxRes +from eth_typing import ChecksumAddress +from web3.contract.contract import ContractFunction +from web3.types import Wei +from skale.contracts.base_contract import transaction_method +from skale.contracts.skale_manager_contract import SkaleManagerContract +from skale.types.validator import ValidatorId -def formatter(method): + +class EarnedData(TypedDict): + earned: Wei + end_month: int + + +def formatter(method: Callable[..., Tuple[Wei, int]]) -> Callable[..., EarnedData]: @wraps(method) - def wrapper(self, *args, **kwargs): + def wrapper(self: SkaleManagerContract, *args: Any, **kwargs: Any) -> EarnedData: res = method(self, *args, **kwargs) - return { + return EarnedData({ 'earned': res[0], 'end_month': res[1], - } + }) return wrapper -class Distributor(BaseContract): +class Distributor(SkaleManagerContract): """Wrapper for Distributor.sol functions""" @formatter - def get_earned_bounty_amount(self, validator_id: int, address: str) -> dict: + def get_earned_bounty_amount( + self, + validator_id: ValidatorId, + address: ChecksumAddress + ) -> Tuple[Wei, int]: """Get earned bounty amount for the validator :param validator_id: ID of the validator @@ -46,12 +61,12 @@ def get_earned_bounty_amount(self, validator_id: int, address: str) -> dict: :returns: Earned bounty amount and end month :rtype: dict """ - return self.contract.functions.getAndUpdateEarnedBountyAmount(validator_id).call({ + return tuple(self.contract.functions.getAndUpdateEarnedBountyAmount(validator_id).call({ 'from': address - }) + })) @formatter - def get_earned_fee_amount(self, address: str) -> dict: + def get_earned_fee_amount(self, address: str) -> Tuple[Wei, int]: """Get earned fee amount for the address :param address: Address of the validator @@ -59,12 +74,12 @@ def get_earned_fee_amount(self, address: str) -> dict: :returns: Earned bounty amount and end month :rtype: dict """ - return self.contract.functions.getEarnedFeeAmount().call({ + return tuple(self.contract.functions.getEarnedFeeAmount().call({ 'from': address - }) + })) @transaction_method - def withdraw_bounty(self, validator_id: int, to: str) -> TxRes: + def withdraw_bounty(self, validator_id: ValidatorId, to: ChecksumAddress) -> ContractFunction: """Withdraw earned bounty to specified address :param validator_id: ID of the validator @@ -77,7 +92,7 @@ def withdraw_bounty(self, validator_id: int, to: str) -> TxRes: return self.contract.functions.withdrawBounty(validator_id, to) @transaction_method - def withdraw_fee(self, to: str) -> TxRes: + def withdraw_fee(self, to: ChecksumAddress) -> ContractFunction: """Withdraw earned fee to specified address :param to: Address to transfer bounty diff --git a/skale/contracts/manager/delegation/slashing_table.py b/skale/contracts/manager/delegation/slashing_table.py index c2c31f00..240a4dd5 100644 --- a/skale/contracts/manager/delegation/slashing_table.py +++ b/skale/contracts/manager/delegation/slashing_table.py @@ -1,12 +1,16 @@ -from skale.contracts.base_contract import BaseContract, transaction_method -from skale.transactions.result import TxRes +from eth_typing import ChecksumAddress +from web3.contract.contract import ContractFunction +from web3.types import Wei +from skale.contracts.base_contract import transaction_method +from skale.contracts.skale_manager_contract import SkaleManagerContract -class SlashingTable(BaseContract): + +class SlashingTable(SkaleManagerContract): """ Wrapper for SlashingTable.sol functions """ @transaction_method - def set_penalty(self, offense, penalty) -> TxRes: + def set_penalty(self, offense: str, penalty: Wei) -> ContractFunction: """ Set slashing penalty :param offense: reason of slashing :type offense: str @@ -16,20 +20,20 @@ def set_penalty(self, offense, penalty) -> TxRes: """ return self.contract.functions.setPenalty(offense, penalty) - def get_penalty(self, offense) -> int: + def get_penalty(self, offense: str) -> Wei: """ Get slashing penalty value :param offense: reason of slashing :type offense: str :rtype: int """ - return self.contract.functions.getPenalty(offense).call() + return Wei(self.contract.functions.getPenalty(offense).call()) @transaction_method - def grant_role(self, role: bytes, address: str) -> TxRes: + def grant_role(self, role: bytes, address: ChecksumAddress) -> ContractFunction: return self.contract.functions.grantRole(role, address) def penalty_setter_role(self) -> bytes: - return self.contract.functions.PENALTY_SETTER_ROLE().call() + return bytes(self.contract.functions.PENALTY_SETTER_ROLE().call()) - def has_role(self, role: bytes, address: str) -> bool: - return self.contract.functions.hasRole(role, address).call() + def has_role(self, role: bytes, address: ChecksumAddress) -> bool: + return bool(self.contract.functions.hasRole(role, address).call()) diff --git a/skale/contracts/manager/delegation/token_state.py b/skale/contracts/manager/delegation/token_state.py index 1871e0e9..6a14b9de 100644 --- a/skale/contracts/manager/delegation/token_state.py +++ b/skale/contracts/manager/delegation/token_state.py @@ -17,28 +17,32 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . -from skale.contracts.base_contract import BaseContract, transaction_method -from skale.transactions.result import TxRes +from eth_typing import ChecksumAddress +from web3.contract.contract import ContractFunction +from web3.types import Wei +from skale.contracts.base_contract import transaction_method +from skale.contracts.skale_manager_contract import SkaleManagerContract -class TokenState(BaseContract): + +class TokenState(SkaleManagerContract): """Wrapper for TokenState.sol functions""" - def get_and_update_locked_amount(self, holder_address: str) -> int: + def get_and_update_locked_amount(self, holder_address: ChecksumAddress) -> Wei: """This method is for check quantity of `freezed` tokens :param holder_address: Address of the holder :type holder_address: str :returns: :rtype: int """ - return self.contract.functions.getAndUpdateLockedAmount(holder_address).call() + return Wei(self.contract.functions.getAndUpdateLockedAmount(holder_address).call()) @transaction_method - def grant_role(self, role: bytes, owner: str) -> TxRes: + def grant_role(self, role: bytes, owner: ChecksumAddress) -> ContractFunction: return self.contract.functions.grantRole(role, owner) - def has_role(self, role: bytes, address: str) -> bool: - return self.contract.functions.hasRole(role, address).call() + def has_role(self, role: bytes, address: ChecksumAddress) -> bool: + return bool(self.contract.functions.hasRole(role, address).call()) - def locker_manager_role(self): - return self.contract.functions.LOCKER_MANAGER_ROLE().call() + def locker_manager_role(self) -> bytes: + return bytes(self.contract.functions.LOCKER_MANAGER_ROLE().call()) diff --git a/skale/contracts/manager/delegation/validator_service.py b/skale/contracts/manager/delegation/validator_service.py index ffba3305..534e759a 100644 --- a/skale/contracts/manager/delegation/validator_service.py +++ b/skale/contracts/manager/delegation/validator_service.py @@ -17,13 +17,17 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . +from typing import Any, Dict, List +from eth_typing import ChecksumAddress from web3 import Web3 +from web3.contract.contract import ContractFunction +from web3.types import Wei -from skale.contracts.base_contract import BaseContract, transaction_method +from skale.contracts.base_contract import transaction_method +from skale.contracts.skale_manager_contract import SkaleManagerContract +from skale.types.validator import Validator, ValidatorId, ValidatorWithId from skale.utils.helper import format_fields -from skale.transactions.result import TxRes - FIELDS = [ 'name', 'validator_address', 'requested_address', 'description', 'fee_rate', @@ -32,19 +36,19 @@ ] -class ValidatorService(BaseContract): +class ValidatorService(SkaleManagerContract): """Wrapper for ValidatorService.sol functions""" - def __get_raw(self, _id) -> list: + def __get_raw(self, _id: ValidatorId) -> List[Any]: """Returns raw validator info. :returns: Raw validator info :rtype: list """ - return self.contract.functions.validators(_id).call() + return list(self.contract.functions.validators(_id).call()) @format_fields(FIELDS) - def get(self, _id) -> dict: + def untyped_get(self, _id: ValidatorId) -> List[Any]: """Returns validator info. :returns: Validator info @@ -55,25 +59,34 @@ def get(self, _id) -> dict: validator.append(trusted) return validator - def get_with_id(self, _id) -> dict: + def get(self, _id: ValidatorId) -> Validator: + untyped_validator = self.untyped_get(_id) + if untyped_validator is None: + raise ValueError('Validator with id ', _id, ' is missing') + if isinstance(untyped_validator, dict): + return self._to_validator(untyped_validator) + if isinstance(untyped_validator, list): + return self._to_validator(untyped_validator[0]) + raise TypeError(_id) + + def get_with_id(self, _id: ValidatorId) -> ValidatorWithId: """Returns validator info with ID. :returns: Validator info with ID :rtype: dict """ validator = self.get(_id) - validator['id'] = _id - return validator + return ValidatorWithId({'id': _id, **validator}) - def number_of_validators(self): + def number_of_validators(self) -> int: """Returns number of registered validators. :returns: List of validators :rtype: int """ - return self.contract.functions.numberOfValidators().call() + return int(self.contract.functions.numberOfValidators().call()) - def ls(self, trusted_only=False): + def ls(self, trusted_only: bool = False) -> List[ValidatorWithId]: """Returns list of registered validators. :returns: List of validators @@ -84,30 +97,42 @@ def ls(self, trusted_only=False): self.get_with_id(val_id) for val_id in self.get_trusted_validator_ids() ] if trusted_only else [ - self.get_with_id(val_id) + self.get_with_id(ValidatorId(val_id)) for val_id in range(1, number_of_validators + 1) ] return validators - def get_linked_addresses_by_validator_address(self, address: str) -> list: + def get_linked_addresses_by_validator_address( + self, + address: ChecksumAddress + ) -> List[ChecksumAddress]: """Returns list of node addresses linked to the validator address. :returns: List of node addresses :rtype: list """ - return self.contract.functions.getMyNodesAddresses().call({ - 'from': address - }) + return [ + Web3.to_checksum_address(address) + for address + in self.contract.functions.getMyNodesAddresses().call({'from': address}) + ] - def get_linked_addresses_by_validator_id(self, validator_id: int) -> list: + def get_linked_addresses_by_validator_id( + self, + validator_id: ValidatorId + ) -> List[ChecksumAddress]: """Returns list of node addresses linked to the validator ID. :returns: List of node addresses :rtype: list """ - return self.contract.functions.getNodeAddresses(validator_id).call() + return [ + Web3.to_checksum_address(address) + for address + in self.contract.functions.getNodeAddresses(validator_id).call() + ] - def is_main_address(self, validator_address: str) -> bool: + def is_main_address(self, validator_address: ChecksumAddress) -> bool: """Checks if provided address is the main validator address :returns: True if provided address is the main validator address, otherwise False @@ -125,59 +150,63 @@ def is_main_address(self, validator_address: str) -> bool: return validator_address == validator['validator_address'] - def validator_address_exists(self, validator_address: str) -> bool: + def validator_address_exists(self, validator_address: ChecksumAddress) -> bool: """Checks if there is a validator with provided address :returns: True if validator exists, otherwise False :rtype: bool """ - return self.contract.functions.validatorAddressExists(validator_address).call() + return bool(self.contract.functions.validatorAddressExists(validator_address).call()) - def validator_exists(self, validator_id: str) -> bool: + def validator_exists(self, validator_id: ValidatorId) -> bool: """Checks if there is a validator with provided ID :returns: True if validator exists, otherwise False :rtype: bool """ - return self.contract.functions.validatorExists(validator_id).call() + return bool(self.contract.functions.validatorExists(validator_id).call()) - def validator_id_by_address(self, validator_address: str) -> int: + def validator_id_by_address(self, validator_address: ChecksumAddress) -> ValidatorId: """Returns validator ID by validator address :returns: Validator ID :rtype: int """ - return self.contract.functions.getValidatorId(validator_address).call() + return ValidatorId(self.contract.functions.getValidatorId(validator_address).call()) - def get_trusted_validator_ids(self) -> list: + def get_trusted_validator_ids(self) -> List[ValidatorId]: """Returns list of trusted validators id. :returns: List of trusted validators id :rtype: list """ - return self.contract.functions.getTrustedValidators().call() + return [ + ValidatorId(id) + for id + in self.contract.functions.getTrustedValidators().call() + ] @transaction_method - def _enable_validator(self, validator_id: int) -> TxRes: + def _enable_validator(self, validator_id: ValidatorId) -> ContractFunction: """For internal usage only""" return self.contract.functions.enableValidator(validator_id) @transaction_method - def _disable_validator(self, validator_id: int) -> TxRes: + def _disable_validator(self, validator_id: ValidatorId) -> ContractFunction: """For internal usage only""" return self.contract.functions.disableValidator(validator_id) - def _is_authorized_validator(self, validator_id: int) -> bool: + def _is_authorized_validator(self, validator_id: ValidatorId) -> bool: """For internal usage only""" - return self.contract.functions.isAuthorizedValidator(validator_id).call() + return bool(self.contract.functions.isAuthorizedValidator(validator_id).call()) - def is_accepting_new_requests(self, validator_id: int) -> bool: + def is_accepting_new_requests(self, validator_id: ValidatorId) -> bool: """For internal usage only""" - return self.contract.functions.isAcceptingNewRequests(validator_id).call() + return bool(self.contract.functions.isAcceptingNewRequests(validator_id).call()) @transaction_method def register_validator(self, name: str, description: str, fee_rate: int, - min_delegation_amount: int) -> TxRes: + min_delegation_amount: int) -> ContractFunction: """Registers a new validator in the SKALE Manager contracts. :param name: Validator name @@ -194,13 +223,13 @@ def register_validator(self, name: str, description: str, fee_rate: int, return self.contract.functions.registerValidator( name, description, fee_rate, min_delegation_amount) - def get_link_node_signature(self, validator_id: int) -> str: + def get_link_node_signature(self, validator_id: ValidatorId) -> str: unsigned_hash = Web3.solidity_keccak(['uint256'], [validator_id]) signed_hash = self.skale.wallet.sign_hash(unsigned_hash.hex()) return signed_hash.signature.hex() @transaction_method - def link_node_address(self, node_address: str, signature: str) -> TxRes: + def link_node_address(self, node_address: ChecksumAddress, signature: str) -> ContractFunction: """Link node address to your validator account. :param node_address: Address of the node to link @@ -213,7 +242,7 @@ def link_node_address(self, node_address: str, signature: str) -> TxRes: return self.contract.functions.linkNodeAddress(node_address, signature) @transaction_method - def unlink_node_address(self, node_address: str) -> TxRes: + def unlink_node_address(self, node_address: ChecksumAddress) -> ContractFunction: """Unlink node address from your validator account. :param node_address: Address of the node to unlink @@ -224,7 +253,7 @@ def unlink_node_address(self, node_address: str) -> TxRes: return self.contract.functions.unlinkNodeAddress(node_address) @transaction_method - def disable_whitelist(self) -> TxRes: + def disable_whitelist(self) -> ContractFunction: """ Disable validator whitelist. Master key only transaction. :returns: Transaction results :rtype: TxRes @@ -236,18 +265,18 @@ def get_use_whitelist(self) -> bool: :returns: useWhitelist value :rtype: bool """ - return self.contract.functions.useWhitelist().call() + return bool(self.contract.functions.useWhitelist().call()) - def get_and_update_bond_amount(self, validator_id: int) -> int: + def get_and_update_bond_amount(self, validator_id: ValidatorId) -> int: """Return amount of token that validator delegated to himself :param validator_id: id of the validator :returns: :rtype: int """ - return self.contract.functions.getAndUpdateBondAmount(validator_id).call() + return int(self.contract.functions.getAndUpdateBondAmount(validator_id).call()) @transaction_method - def set_validator_mda(self, minimum_delegation_amount: int) -> TxRes: + def set_validator_mda(self, minimum_delegation_amount: Wei) -> ContractFunction: """ Allows a validator to set the minimum delegation amount. :param new_minimum_delegation_amount: Minimum delegation amount @@ -258,7 +287,7 @@ def set_validator_mda(self, minimum_delegation_amount: int) -> TxRes: return self.contract.functions.setValidatorMDA(minimum_delegation_amount) @transaction_method - def request_for_new_address(self, new_validator_address: str) -> TxRes: + def request_for_new_address(self, new_validator_address: ChecksumAddress) -> ContractFunction: """ Allows a validator to request a new address. :param new_validator_address: New validator address @@ -269,7 +298,7 @@ def request_for_new_address(self, new_validator_address: str) -> TxRes: return self.contract.functions.requestForNewAddress(new_validator_address) @transaction_method - def confirm_new_address(self, validator_id: int) -> TxRes: + def confirm_new_address(self, validator_id: ValidatorId) -> ContractFunction: """ Confirm change of the address. :param validator_id: ID of the validator @@ -280,7 +309,7 @@ def confirm_new_address(self, validator_id: int) -> TxRes: return self.contract.functions.confirmNewAddress(validator_id) @transaction_method - def set_validator_name(self, new_name: str) -> TxRes: + def set_validator_name(self, new_name: str) -> ContractFunction: """ Allows a validator to change the name. :param new_name: New validator name @@ -291,7 +320,7 @@ def set_validator_name(self, new_name: str) -> TxRes: return self.contract.functions.setValidatorName(new_name) @transaction_method - def set_validator_description(self, new_description: str) -> TxRes: + def set_validator_description(self, new_description: str) -> ContractFunction: """ Allows a validator to change the name. :param new_description: New validator description @@ -302,11 +331,24 @@ def set_validator_description(self, new_description: str) -> TxRes: return self.contract.functions.setValidatorDescription(new_description) @transaction_method - def grant_role(self, role: bytes, address: str) -> TxRes: + def grant_role(self, role: bytes, address: ChecksumAddress) -> ContractFunction: return self.contract.functions.grantRole(role, address) def validator_manager_role(self) -> bytes: - return self.contract.functions.VALIDATOR_MANAGER_ROLE().call() - - def has_role(self, role: bytes, address: str) -> bool: - return self.contract.functions.hasRole(role, address).call() + return bytes(self.contract.functions.VALIDATOR_MANAGER_ROLE().call()) + + def has_role(self, role: bytes, address: ChecksumAddress) -> bool: + return bool(self.contract.functions.hasRole(role, address).call()) + + def _to_validator(self, untyped_validator: Dict[str, Any]) -> Validator: + return Validator({ + 'name': str(untyped_validator['name']), + 'validator_address': ChecksumAddress(untyped_validator['validator_address']), + 'requested_address': ChecksumAddress(untyped_validator['requested_address']), + 'description': str(untyped_validator['description']), + 'fee_rate': int(untyped_validator['fee_rate']), + 'registration_time': int(untyped_validator['registration_time']), + 'minimum_delegation_amount': Wei(untyped_validator['minimum_delegation_amount']), + 'accept_new_requests': bool(untyped_validator['accept_new_requests']), + 'trusted': bool(untyped_validator['trusted']) + }) diff --git a/skale/contracts/manager/dkg.py b/skale/contracts/manager/dkg.py index 47d22c79..8e601310 100644 --- a/skale/contracts/manager/dkg.py +++ b/skale/contracts/manager/dkg.py @@ -17,45 +17,47 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . -from skale.contracts.base_contract import BaseContract, transaction_method -from skale.transactions.tools import retry_tx -from skale.utils.helper import split_public_key - - -class KeyShare: - def __init__(self, public_key: str, share: bytes): - self.public_key = split_public_key(public_key) - self.share = share - self.tuple = (self.public_key, self.share) +from typing import List, Tuple +from eth_typing import ChecksumAddress +from web3.contract.contract import ContractFunction - -class G2Point: - def __init__(self, xa, xb, ya, yb): - self.x = (xa, xb) - self.y = (ya, yb) - self.tuple = (self.x, self.y) +from skale.contracts.base_contract import transaction_method +from skale.contracts.skale_manager_contract import SkaleManagerContract +from skale.transactions.tools import retry_tx +from skale.types.dkg import G2Point, KeyShare, VerificationVector +from skale.types.node import NodeId +from skale.types.schain import SchainHash -class DKG(BaseContract): +class DKG(SkaleManagerContract): @retry_tx @transaction_method - def broadcast(self, group_index, node_index, - verification_vector, secret_key_contribution, rotation_id): - return self.contract.functions.broadcast(group_index, node_index, - verification_vector, - secret_key_contribution, - rotation_id) + def broadcast( + self, + group_index: SchainHash, + node_index: NodeId, + verification_vector: VerificationVector, + secret_key_contribution: List[KeyShare], + rotation_id: int + ) -> ContractFunction: + return self.contract.functions.broadcast( + group_index, + node_index, + verification_vector, + secret_key_contribution, + rotation_id + ) @retry_tx @transaction_method def pre_response( self, - group_index: str, - from_node_index: int, - verification_vector: list, - verification_vector_mult: list, - secret_key_contribution: list - ): + group_index: SchainHash, + from_node_index: NodeId, + verification_vector: VerificationVector, + verification_vector_mult: VerificationVector, + secret_key_contribution: List[KeyShare] + ) -> ContractFunction: return self.contract.functions.preResponse( group_index, fromNodeIndex=from_node_index, @@ -68,11 +70,11 @@ def pre_response( @transaction_method def response( self, - group_index: bytes, - from_node_index: int, + group_index: SchainHash, + from_node_index: NodeId, secret_number: int, multiplied_share: G2Point - ): + ) -> ContractFunction: return self.contract.functions.response( group_index, fromNodeIndex=from_node_index, @@ -82,78 +84,118 @@ def response( @retry_tx @transaction_method - def alright(self, group_index, from_node_index): + def alright(self, group_index: SchainHash, from_node_index: NodeId) -> ContractFunction: return self.contract.functions.alright(group_index, from_node_index) @retry_tx @transaction_method - def complaint(self, group_index, from_node_index, to_node_index): + def complaint( + self, + group_index: SchainHash, + from_node_index: NodeId, + to_node_index: NodeId + ) -> ContractFunction: return self.contract.functions.complaint(group_index, from_node_index, to_node_index) @retry_tx @transaction_method - def complaint_bad_data(self, group_index, from_node_index, to_node_index): + def complaint_bad_data( + self, + group_index: SchainHash, + from_node_index: NodeId, + to_node_index: NodeId + ) -> ContractFunction: return self.contract.functions.complaintBadData(group_index, from_node_index, to_node_index) - def is_last_dkg_successful(self, group_index): - return self.contract.functions.isLastDKGSuccessful(group_index).call() + def is_last_dkg_successful(self, group_index: SchainHash) -> bool: + return bool(self.contract.functions.isLastDKGSuccessful(group_index).call()) - def is_channel_opened(self, group_index): - return self.contract.functions.isChannelOpened(group_index).call() + def is_channel_opened(self, group_index: SchainHash) -> bool: + return bool(self.contract.functions.isChannelOpened(group_index).call()) - def is_broadcast_possible(self, group_index, node_id, address): - return self.contract.functions.isBroadcastPossible(group_index, node_id).call( + def is_broadcast_possible( + self, + group_index: SchainHash, + node_id: NodeId, + address: ChecksumAddress + ) -> bool: + return bool(self.contract.functions.isBroadcastPossible(group_index, node_id).call( {'from': address} - ) + )) - def is_alright_possible(self, group_index, node_id, address): - return self.contract.functions.isAlrightPossible(group_index, node_id).call( + def is_alright_possible( + self, + group_index: SchainHash, + node_id: NodeId, + address: ChecksumAddress + ) -> bool: + return bool(self.contract.functions.isAlrightPossible(group_index, node_id).call( {'from': address} - ) + )) - def is_complaint_possible(self, group_index, node_from, node_to, address): - return self.contract.functions.isComplaintPossible(group_index, node_from, node_to).call( - {'from': address} + def is_complaint_possible( + self, + group_index: SchainHash, + node_from: NodeId, + node_to: NodeId, + address: ChecksumAddress + ) -> bool: + return bool( + self.contract.functions.isComplaintPossible( + group_index, + node_from, + node_to + ).call({'from': address}) ) - def is_pre_response_possible(self, group_index, node_id, address): - return self.contract.functions.isPreResponsePossible(group_index, node_id).call( + def is_pre_response_possible( + self, + group_index: SchainHash, + node_id: NodeId, + address: ChecksumAddress + ) -> bool: + return bool(self.contract.functions.isPreResponsePossible(group_index, node_id).call( {'from': address} - ) + )) - def is_response_possible(self, group_index, node_id, address): - return self.contract.functions.isResponsePossible(group_index, node_id).call( + def is_response_possible( + self, + group_index: SchainHash, + node_id: NodeId, + address: ChecksumAddress + ) -> bool: + return bool(self.contract.functions.isResponsePossible(group_index, node_id).call( {'from': address} - ) + )) - def is_all_data_received(self, group_index, node_from): - return self.contract.functions.isAllDataReceived(group_index, node_from).call() + def is_all_data_received(self, group_index: SchainHash, node_from: NodeId) -> bool: + return bool(self.contract.functions.isAllDataReceived(group_index, node_from).call()) - def is_everyone_broadcasted(self, group_index, address): - return self.contract.functions.isEveryoneBroadcasted(group_index).call( + def is_everyone_broadcasted(self, group_index: SchainHash, address: ChecksumAddress) -> bool: + return bool(self.contract.functions.isEveryoneBroadcasted(group_index).call( {'from': address} - ) + )) - def get_number_of_completed(self, group_index): - return self.contract.functions.getNumberOfCompleted(group_index).call() + def get_number_of_completed(self, group_index: SchainHash) -> int: + return int(self.contract.functions.getNumberOfCompleted(group_index).call()) - def get_channel_started_time(self, group_index): - return self.contract.functions.getChannelStartedTime(group_index).call() + def get_channel_started_time(self, group_index: SchainHash) -> int: + return int(self.contract.functions.getChannelStartedTime(group_index).call()) - def get_complaint_started_time(self, group_index): - return self.contract.functions.getComplaintStartedTime(group_index).call() + def get_complaint_started_time(self, group_index: SchainHash) -> int: + return int(self.contract.functions.getComplaintStartedTime(group_index).call()) - def get_alright_started_time(self, group_index): - return self.contract.functions.getAlrightStartedTime(group_index).call() + def get_alright_started_time(self, group_index: SchainHash) -> int: + return int(self.contract.functions.getAlrightStartedTime(group_index).call()) - def get_complaint_data(self, group_index): - return self.contract.functions.getComplaintData(group_index).call() + def get_complaint_data(self, group_index: SchainHash) -> Tuple[NodeId, NodeId]: + return tuple(self.contract.functions.getComplaintData(group_index).call()) - def get_time_of_last_successful_dkg(self, group_index): - return self.contract.functions.getTimeOfLastSuccessfulDKG(group_index).call() + def get_time_of_last_successful_dkg(self, group_index: SchainHash) -> int: + return int(self.contract.functions.getTimeOfLastSuccessfulDKG(group_index).call()) - def is_node_broadcasted(self, group_index: int, node_id: int) -> bool: - return self.contract.functions.isNodeBroadcasted(group_index, node_id).call() + def is_node_broadcasted(self, group_index: SchainHash, node_id: NodeId) -> bool: + return bool(self.contract.functions.isNodeBroadcasted(group_index, node_id).call()) diff --git a/skale/contracts/manager/groups.py b/skale/contracts/manager/groups.py index b67d8735..2adb43cf 100644 --- a/skale/contracts/manager/groups.py +++ b/skale/contracts/manager/groups.py @@ -18,8 +18,8 @@ # along with SKALE.py. If not, see . """ SKALE group class """ -from skale.contracts.base_contract import BaseContract +from skale.contracts.skale_manager_contract import SkaleManagerContract -class Groups(BaseContract): +class Groups(SkaleManagerContract): pass diff --git a/skale/contracts/manager/key_storage.py b/skale/contracts/manager/key_storage.py index 353a35b8..575655ae 100644 --- a/skale/contracts/manager/key_storage.py +++ b/skale/contracts/manager/key_storage.py @@ -17,15 +17,22 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . -from skale.contracts.base_contract import BaseContract +from typing import List +from skale.types.dkg import G2Point +from skale.contracts.skale_manager_contract import SkaleManagerContract +from skale.types.schain import SchainHash -class KeyStorage(BaseContract): - def get_common_public_key(self, group_index): - return self.contract.functions.getCommonPublicKey(group_index).call() +class KeyStorage(SkaleManagerContract): + def get_common_public_key(self, schain_hash: SchainHash) -> G2Point: + return G2Point(*self.contract.functions.getCommonPublicKey(schain_hash).call()) - def get_previous_public_key(self, group_index): - return self.contract.functions.getPreviousPublicKey(group_index).call() + def get_previous_public_key(self, schain_hash: SchainHash) -> G2Point: + return G2Point(*self.contract.functions.getPreviousPublicKey(schain_hash).call()) - def get_all_previous_public_keys(self, group_index): - return self.contract.functions.getAllPreviousPublicKeys(group_index).call() + def get_all_previous_public_keys(self, schain_hash: SchainHash) -> List[G2Point]: + return [ + G2Point(*key) + for key + in self.contract.functions.getAllPreviousPublicKeys(schain_hash).call() + ] diff --git a/skale/contracts/manager/manager.py b/skale/contracts/manager/manager.py index e3a844f8..5c364fa9 100644 --- a/skale/contracts/manager/manager.py +++ b/skale/contracts/manager/manager.py @@ -22,9 +22,14 @@ import logging import socket -from eth_abi import encode - -from skale.contracts.base_contract import BaseContract, transaction_method +from eth_abi.abi import encode +from eth_typing import ChecksumAddress +from web3.contract.contract import ContractFunction + +from skale.contracts.base_contract import transaction_method +from skale.contracts.skale_manager_contract import SkaleManagerContract +from skale.types.node import NodeId, Port +from skale.types.schain import SchainName from skale.utils import helper from skale.transactions.result import TxRes from skale.dataclasses.schain_options import ( @@ -34,10 +39,17 @@ logger = logging.getLogger(__name__) -class Manager(BaseContract): +class Manager(SkaleManagerContract): @transaction_method - def create_node(self, ip, port, name, domain_name, public_ip=None): + def create_node( + self, + ip: str, + port: Port, + name: str, + domain_name: str, + public_ip: str | None = None + ) -> ContractFunction: logger.info( f'create_node: {ip}:{port}, name: {name}, domain_name: {domain_name}') skale_nonce = helper.generate_nonce() @@ -56,7 +68,7 @@ def create_node(self, ip, port, name, domain_name, public_ip=None): domain_name ) - def create_default_schain(self, name): + def create_default_schain(self, name: SchainName) -> TxRes: lifetime = 3600 nodes_type = self.skale.schains_internal.number_of_schain_types() price_in_wei = self.skale.schains.get_schain_price( @@ -70,10 +82,10 @@ def create_schain( lifetime: int, type_of_nodes: int, deposit: str, - name: str, - schain_originator: str = None, - options: SchainOptions = None - ): + name: SchainName, + schain_originator: ChecksumAddress | None = None, + options: SchainOptions | None = None + ) -> ContractFunction: logger.info( f'create_schain: type_of_nodes: {type_of_nodes}, name: {name}') skale_nonce = helper.generate_nonce() @@ -95,33 +107,33 @@ def create_schain( ) @transaction_method - def get_bounty(self, node_id): + def get_bounty(self, node_id: NodeId) -> ContractFunction: return self.contract.functions.getBounty(node_id) @transaction_method - def delete_schain(self, schain_name): + def delete_schain(self, schain_name: SchainName) -> ContractFunction: return self.contract.functions.deleteSchain(schain_name) @transaction_method - def delete_schain_by_root(self, schain_name): + def delete_schain_by_root(self, schain_name: SchainName) -> ContractFunction: return self.contract.functions.deleteSchainByRoot(schain_name) @transaction_method - def node_exit(self, node_id): + def node_exit(self, node_id: NodeId) -> ContractFunction: return self.contract.functions.nodeExit(node_id) @transaction_method - def grant_role(self, role: bytes, address: str) -> TxRes: + def grant_role(self, role: bytes, address: ChecksumAddress) -> ContractFunction: return self.contract.functions.grantRole(role, address) def default_admin_role(self) -> bytes: - return self.contract.functions.DEFAULT_ADMIN_ROLE().call() + return bytes(self.contract.functions.DEFAULT_ADMIN_ROLE().call()) def admin_role(self) -> bytes: - return self.contract.functions.ADMIN_ROLE().call() + return bytes(self.contract.functions.ADMIN_ROLE().call()) def schain_removal_role(self) -> bytes: - return self.contract.functions.SCHAIN_REMOVAL_ROLE().call() + return bytes(self.contract.functions.SCHAIN_REMOVAL_ROLE().call()) - def has_role(self, role: bytes, address: str) -> bool: - return self.contract.functions.hasRole(role, address).call() + def has_role(self, role: bytes, address: ChecksumAddress) -> bool: + return bool(self.contract.functions.hasRole(role, address).call()) diff --git a/skale/contracts/manager/node_rotation.py b/skale/contracts/manager/node_rotation.py index 8e3a9a9f..66db202e 100644 --- a/skale/contracts/manager/node_rotation.py +++ b/skale/contracts/manager/node_rotation.py @@ -18,83 +18,76 @@ # along with SKALE.py. If not, see . """ NodeRotation.sol functions """ +from __future__ import annotations import logging import functools -import warnings -from dataclasses import dataclass +from typing import TYPE_CHECKING, List -from skale.contracts.base_contract import BaseContract, transaction_method -from skale.transactions.result import TxRes +from eth_typing import ChecksumAddress + +from skale.contracts.base_contract import transaction_method +from web3.contract.contract import ContractFunction from web3.exceptions import ContractLogicError +from skale.contracts.skale_manager_contract import SkaleManagerContract +from skale.types.node import NodeId +from skale.types.rotation import Rotation, RotationSwap +from skale.types.schain import SchainHash, SchainName -logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from skale.contracts.manager.schains import SChains -NO_PREVIOUS_NODE_EXCEPTION_TEXT = 'No previous node' +logger = logging.getLogger(__name__) -@dataclass -class Rotation: - leaving_node_id: int - new_node_id: int - freeze_until: int - rotation_counter: int +NO_PREVIOUS_NODE_EXCEPTION_TEXT = 'No previous node' -class NodeRotation(BaseContract): +class NodeRotation(SkaleManagerContract): """Wrapper for NodeRotation.sol functions""" @property @functools.lru_cache() - def schains(self): + def schains(self) -> SChains: return self.skale.schains - def get_rotation_obj(self, schain_name) -> Rotation: + def get_rotation(self, schain_name: SchainName) -> Rotation: schain_id = self.schains.name_to_id(schain_name) rotation_data = self.contract.functions.getRotation(schain_id).call() return Rotation(*rotation_data) - def get_rotation(self, schain_name): - warnings.warn('Deprecated, will be removed in v6', DeprecationWarning) - schain_id = self.schains.name_to_id(schain_name) - rotation_data = self.contract.functions.getRotation(schain_id).call() - return { - 'leaving_node': rotation_data[0], - 'new_node': rotation_data[1], - 'freeze_until': rotation_data[2], - 'rotation_id': rotation_data[3] - } - - def get_leaving_history(self, node_id): + def get_leaving_history(self, node_id: NodeId) -> List[RotationSwap]: raw_history = self.contract.functions.getLeavingHistory(node_id).call() history = [ - { - 'schain_id': schain[0], - 'finished_rotation': schain[1] - } + RotationSwap({ + 'schain_id': SchainHash(schain[0]), + 'finished_rotation': int(schain[1]) + }) for schain in raw_history ] return history - def get_schain_finish_ts(self, node_id: int, schain_name: str) -> int: - raw_history = self.contract.functions.getLeavingHistory(node_id).call() + def get_schain_finish_ts(self, node_id: NodeId, schain_name: SchainName) -> int | None: + history = self.get_leaving_history(node_id) schain_id = self.skale.schains.name_to_id(schain_name) finish_ts = next( - (schain[1] for schain in raw_history if '0x' + schain[0].hex() == schain_id), None) + (swap['finished_rotation'] for swap in history if swap['schain_id'] == schain_id), + None + ) if not finish_ts: return None - return finish_ts + return int(finish_ts) - def is_rotation_in_progress(self, schain_name) -> bool: + def is_rotation_in_progress(self, schain_name: SchainName) -> bool: schain_id = self.schains.name_to_id(schain_name) - return self.contract.functions.isRotationInProgress(schain_id).call() + return bool(self.contract.functions.isRotationInProgress(schain_id).call()) - def is_new_node_found(self, schain_name) -> bool: + def is_new_node_found(self, schain_name: SchainName) -> bool: schain_id = self.schains.name_to_id(schain_name) - return self.contract.functions.isNewNodeFound(schain_id).call() + return bool(self.contract.functions.isNewNodeFound(schain_id).call()) - def is_rotation_active(self, schain_name) -> bool: + def is_rotation_active(self, schain_name: SchainName) -> bool: """ The public function that tells whether rotation is in the active phase - the new group is already generated @@ -102,8 +95,8 @@ def is_rotation_active(self, schain_name) -> bool: finish_ts_reached = self.is_finish_ts_reached(schain_name) return self.is_rotation_in_progress(schain_name) and not finish_ts_reached - def is_finish_ts_reached(self, schain_name) -> bool: - rotation = self.skale.node_rotation.get_rotation_obj(schain_name) + def is_finish_ts_reached(self, schain_name: SchainName) -> bool: + rotation = self.skale.node_rotation.get_rotation(schain_name) schain_finish_ts = self.get_schain_finish_ts(rotation.leaving_node_id, schain_name) if not schain_finish_ts: @@ -115,24 +108,24 @@ def is_finish_ts_reached(self, schain_name) -> bool: logger.info(f'current_ts: {current_ts}, schain_finish_ts: {schain_finish_ts}') return current_ts > schain_finish_ts - def wait_for_new_node(self, schain_name): + def wait_for_new_node(self, schain_name: SchainName) -> bool: schain_id = self.schains.name_to_id(schain_name) - return self.contract.functions.waitForNewNode(schain_id).call() + return bool(self.contract.functions.waitForNewNode(schain_id).call()) @transaction_method - def grant_role(self, role: bytes, owner: str) -> TxRes: + def grant_role(self, role: bytes, owner: str) -> ContractFunction: return self.contract.functions.grantRole(role, owner) - def has_role(self, role: bytes, address: str) -> bool: - return self.contract.functions.hasRole(role, address).call() + def has_role(self, role: bytes, address: ChecksumAddress) -> bool: + return bool(self.contract.functions.hasRole(role, address).call()) - def debugger_role(self): - return self.contract.functions.DEBUGGER_ROLE().call() + def debugger_role(self) -> bytes: + return bytes(self.contract.functions.DEBUGGER_ROLE().call()) - def get_previous_node(self, schain_name: str, node_id: int) -> int: + def get_previous_node(self, schain_name: SchainName, node_id: NodeId) -> NodeId | None: schain_id = self.schains.name_to_id(schain_name) try: - return self.contract.functions.getPreviousNode(schain_id, node_id).call() + return NodeId(self.contract.functions.getPreviousNode(schain_id, node_id).call()) except (ContractLogicError, ValueError) as e: if NO_PREVIOUS_NODE_EXCEPTION_TEXT in str(e): return None diff --git a/skale/contracts/manager/nodes.py b/skale/contracts/manager/nodes.py index 86af9360..2ca59a0f 100644 --- a/skale/contracts/manager/nodes.py +++ b/skale/contracts/manager/nodes.py @@ -19,14 +19,18 @@ """ Nodes.sol functions """ import socket -from enum import IntEnum +from typing import Any, Dict, List, Tuple, cast from Crypto.Hash import keccak +from eth_typing import BlockNumber, ChecksumAddress +from web3.contract.contract import ContractFunction from web3.exceptions import BadFunctionCallOutput, ContractLogicError -from skale.contracts.base_contract import BaseContract, transaction_method -from skale.transactions.result import TxRes +from skale.contracts.base_contract import transaction_method +from skale.contracts.skale_manager_contract import SkaleManagerContract +from skale.types.node import Node, NodeId, NodeStatus, Port +from skale.types.validator import ValidatorId from skale.utils.exceptions import InvalidNodeIdError from skale.utils.helper import format_fields @@ -36,141 +40,169 @@ ] -class NodeStatus(IntEnum): - ACTIVE = 0 - LEAVING = 1 - LEFT = 2 - IN_MAINTENANCE = 3 - - -class Nodes(BaseContract): - def __get_raw(self, node_id): +class Nodes(SkaleManagerContract): + def __get_raw(self, node_id: NodeId) -> List[Any]: try: - return self.contract.functions.nodes(node_id).call() + return list(self.contract.functions.nodes(node_id).call()) except (ContractLogicError, ValueError, BadFunctionCallOutput): raise InvalidNodeIdError(node_id) - def __get_raw_w_pk(self, node_id): + def __get_raw_w_pk(self, node_id: NodeId) -> List[Any]: raw_node_struct = self.__get_raw(node_id) raw_node_struct.append(self.get_node_public_key(node_id)) return raw_node_struct - def __get_raw_w_pk_w_domain(self, node_id): + def __get_raw_w_pk_w_domain(self, node_id: NodeId) -> List[Any]: raw_node_struct_w_pk = self.__get_raw_w_pk(node_id) raw_node_struct_w_pk.append(self.get_domain_name(node_id)) return raw_node_struct_w_pk @format_fields(FIELDS) - def get(self, node_id): + def untyped_get(self, node_id: NodeId) -> List[Any]: return self.__get_raw_w_pk_w_domain(node_id) + def get(self, node_id: NodeId) -> Node: + node = self.untyped_get(node_id) + if node is None: + raise ValueError('Node with id ', node_id, ' is not found') + if isinstance(node, dict): + return self._to_node(node) + if isinstance(node, list): + return self._to_node(node[0]) + raise ValueError("Can't process returned node type") + @format_fields(FIELDS) - def get_by_name(self, name): + def get_by_name(self, name: str) -> List[Any]: name_hash = self.name_to_id(name) _id = self.contract.functions.nodesNameToIndex(name_hash).call() return self.__get_raw_w_pk_w_domain(_id) - def get_nodes_number(self): - return self.contract.functions.getNumberOfNodes().call() + def get_nodes_number(self) -> int: + return int(self.contract.functions.getNumberOfNodes().call()) - def get_active_node_ids(self): + def get_active_node_ids(self) -> List[NodeId]: nodes_number = self.get_nodes_number() return [ - node_id + NodeId(node_id) for node_id in range(0, nodes_number) - if self.get_node_status(node_id) == NodeStatus.ACTIVE + if self.get_node_status(NodeId(node_id)) == NodeStatus.ACTIVE ] - def get_active_node_ips(self): + def get_active_node_ips(self) -> List[bytes]: nodes_number = self.get_nodes_number() return [ - self.get(node_id)['ip'] + self.get(NodeId(node_id))['ip'] for node_id in range(0, nodes_number) - if self.get_node_status(node_id) == NodeStatus.ACTIVE + if self.get_node_status(NodeId(node_id)) == NodeStatus.ACTIVE ] - def name_to_id(self, name): + def name_to_id(self, name: str) -> bytes: keccak_hash = keccak.new(data=name.encode("utf8"), digest_bits=256) return keccak_hash.digest() - def is_node_name_available(self, name): + def is_node_name_available(self, name: str) -> bool: node_id = self.name_to_id(name) return not self.contract.functions.nodesNameCheck(node_id).call() - def is_node_ip_available(self, ip): + def is_node_ip_available(self, ip: str) -> bool: ip_bytes = socket.inet_aton(ip) return not self.contract.functions.nodesIPCheck(ip_bytes).call() - def node_name_to_index(self, name): + def node_name_to_index(self, name: str) -> int: name_hash = self.name_to_id(name) - return self.contract.functions.nodesNameToIndex(name_hash).call() + return int(self.contract.functions.nodesNameToIndex(name_hash).call()) - def get_node_status(self, node_id): + def get_node_status(self, node_id: NodeId) -> NodeStatus: try: - return self.contract.functions.getNodeStatus(node_id).call() + return NodeStatus(self.contract.functions.getNodeStatus(node_id).call()) except (ContractLogicError, ValueError, BadFunctionCallOutput): raise InvalidNodeIdError(node_id) - def get_node_finish_time(self, node_id): + def get_node_finish_time(self, node_id: NodeId) -> int: try: - return self.contract.functions.getNodeFinishTime(node_id).call() + return int(self.contract.functions.getNodeFinishTime(node_id).call()) except (ContractLogicError, ValueError, BadFunctionCallOutput): raise InvalidNodeIdError(node_id) - def __get_node_public_key_raw(self, node_id): + def __get_node_public_key_raw(self, node_id: NodeId) -> Tuple[bytes, bytes]: try: - return self.contract.functions.getNodePublicKey(node_id).call() + return cast( + Tuple[bytes, bytes], + self.contract.functions.getNodePublicKey(node_id).call() + ) except (ContractLogicError, ValueError, BadFunctionCallOutput): raise InvalidNodeIdError(node_id) - def get_node_public_key(self, node_id): + def get_node_public_key(self, node_id: NodeId) -> str: raw_key = self.__get_node_public_key_raw(node_id) key_bytes = raw_key[0] + raw_key[1] return self.skale.web3.to_hex(key_bytes) - def get_validator_node_indices(self, validator_id: int) -> list: + def get_validator_node_indices(self, validator_id: int) -> list[NodeId]: """Returns list of node indices to the validator :returns: List of trusted node indices :rtype: list """ - return self.contract.functions.getValidatorNodeIndexes(validator_id).call() + return [ + NodeId(id) + for id + in self.contract.functions.getValidatorNodeIndexes(validator_id).call() + ] - def get_last_change_ip_time(self, node_id: int) -> list: - return self.contract.functions.getLastChangeIpTime(node_id).call() + def get_last_change_ip_time(self, node_id: NodeId) -> int: + return int(self.contract.functions.getLastChangeIpTime(node_id).call()) @transaction_method - def set_node_in_maintenance(self, node_id): + def set_node_in_maintenance(self, node_id: NodeId) -> ContractFunction: return self.contract.functions.setNodeInMaintenance(node_id) @transaction_method - def remove_node_from_in_maintenance(self, node_id): + def remove_node_from_in_maintenance(self, node_id: NodeId) -> ContractFunction: return self.contract.functions.removeNodeFromInMaintenance(node_id) @transaction_method - def set_domain_name(self, node_id: int, domain_name: str): + def set_domain_name(self, node_id: NodeId, domain_name: str) -> ContractFunction: return self.contract.functions.setDomainName(node_id, domain_name) - def get_domain_name(self, node_id: int): - return self.contract.functions.getNodeDomainName(node_id).call() + def get_domain_name(self, node_id: NodeId) -> str: + return str(self.contract.functions.getNodeDomainName(node_id).call()) @transaction_method - def grant_role(self, role: bytes, owner: str) -> TxRes: + def grant_role(self, role: bytes, owner: ChecksumAddress) -> ContractFunction: return self.contract.functions.grantRole(role, owner) - def has_role(self, role: bytes, address: str) -> bool: - return self.contract.functions.hasRole(role, address).call() + def has_role(self, role: bytes, address: ChecksumAddress) -> bool: + return bool(self.contract.functions.hasRole(role, address).call()) - def node_manager_role(self): - return self.contract.functions.NODE_MANAGER_ROLE().call() + def node_manager_role(self) -> bytes: + return bytes(self.contract.functions.NODE_MANAGER_ROLE().call()) - def compliance_role(self): - return self.contract.functions.COMPLIANCE_ROLE().call() + def compliance_role(self) -> bytes: + return bytes(self.contract.functions.COMPLIANCE_ROLE().call()) @transaction_method - def init_exit(self, node_id: int) -> TxRes: + def init_exit(self, node_id: NodeId) -> ContractFunction: return self.contract.functions.initExit(node_id) @transaction_method - def change_ip(self, node_id: int, ip: bytes, public_ip: bytes) -> TxRes: + def change_ip(self, node_id: NodeId, ip: bytes, public_ip: bytes) -> ContractFunction: return self.contract.functions.changeIP(node_id, ip, public_ip) + + def _to_node(self, untyped_node: Dict[str, Any]) -> Node: + for key in Node.__annotations__: + if key not in untyped_node: + raise ValueError(f"Key: {key} is not available in node.") + return Node({ + 'name': str(untyped_node['name']), + 'ip': bytes(untyped_node['ip']), + 'publicIP': bytes(untyped_node['publicIP']), + 'port': Port(untyped_node['port']), + 'start_block': BlockNumber(untyped_node['start_block']), + 'last_reward_date': int(untyped_node['last_reward_date']), + 'finish_time': int(untyped_node['finish_time']), + 'status': NodeStatus(untyped_node['status']), + 'validator_id': ValidatorId(untyped_node['validator_id']), + 'publicKey': str(untyped_node['publicKey']), + 'domain_name': str(untyped_node['domain_name']), + }) diff --git a/skale/contracts/manager/punisher.py b/skale/contracts/manager/punisher.py index bfa14a3d..f65df93b 100644 --- a/skale/contracts/manager/punisher.py +++ b/skale/contracts/manager/punisher.py @@ -17,17 +17,20 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . -from skale.contracts.base_contract import BaseContract, transaction_method -from skale.transactions.result import TxRes +from eth_typing import ChecksumAddress +from web3.contract.contract import ContractFunction +from skale.contracts.base_contract import transaction_method +from skale.contracts.skale_manager_contract import SkaleManagerContract -class Punisher(BaseContract): + +class Punisher(SkaleManagerContract): @transaction_method - def grant_role(self, role: bytes, owner: str) -> TxRes: + def grant_role(self, role: bytes, owner: ChecksumAddress) -> ContractFunction: return self.contract.functions.grantRole(role, owner) - def has_role(self, role: bytes, address: str) -> bool: - return self.contract.functions.hasRole(role, address).call() + def has_role(self, role: bytes, address: ChecksumAddress) -> bool: + return bool(self.contract.functions.hasRole(role, address).call()) - def forgiver_role(self): - return self.contract.functions.FORGIVER_ROLE().call() + def forgiver_role(self) -> bytes: + return bytes(self.contract.functions.FORGIVER_ROLE().call()) diff --git a/skale/contracts/manager/schains.py b/skale/contracts/manager/schains.py index e12d18cd..fca48ba9 100644 --- a/skale/contracts/manager/schains.py +++ b/skale/contracts/manager/schains.py @@ -19,13 +19,22 @@ """ Schains.sol functions """ import functools -from dataclasses import dataclass, asdict +from dataclasses import asdict +from typing import Any, List from Crypto.Hash import keccak - -from skale.contracts.base_contract import BaseContract, transaction_method -from skale.transactions.result import TxRes -from skale.utils.helper import format_fields +from eth_typing import ChecksumAddress, HexStr +from hexbytes import HexBytes +from web3 import Web3 +from web3.contract.contract import ContractFunction +from web3.types import Wei + +from skale.contracts.base_contract import transaction_method +from skale.contracts.manager.node_rotation import NodeRotation +from skale.contracts.manager.schains_internal import SChainsInternal +from skale.contracts.skale_manager_contract import SkaleManagerContract +from skale.types.node import NodeId +from skale.types.schain import SchainHash, SchainName, SchainStructure, SchainStructureWithStatus from skale.dataclasses.schain_options import ( SchainOptions, get_default_schain_options, parse_schain_options ) @@ -33,63 +42,36 @@ FIELDS = [ 'name', 'mainnetOwner', 'indexInOwnerList', 'partOfNode', 'lifetime', 'startDate', 'startBlock', - 'deposit', 'index', 'generation', 'originator', 'chainId', 'multitransactionMode', - 'thresholdEncryption' + 'deposit', 'index', 'generation', 'originator', 'chainId', 'options' ] -@dataclass -class SchainStructure: - name: str - mainnet_owner: str - index_in_owner_list: int - part_of_node: int - lifetime: int - start_date: int - start_block: int - deposit: int - index: int - generation: int - originator: str - chain_id: int - options: SchainOptions - - -class SChains(BaseContract): +class SChains(SkaleManagerContract): """Wrapper for some of the Schains.sol functions""" - def name_to_group_id(self, name): + def name_to_group_id(self, name: SchainName) -> HexBytes: return self.skale.web3.keccak(text=name) @property @functools.lru_cache() - def schains_internal(self): + def schains_internal(self) -> SChainsInternal: return self.skale.schains_internal @property @functools.lru_cache() - def node_rotation(self): + def node_rotation(self) -> NodeRotation: return self.skale.node_rotation - @format_fields(FIELDS) - def get(self, id_, obj=False): + def get(self, id_: SchainHash) -> SchainStructure: res = self.schains_internal.get_raw(id_) - hash_obj = keccak.new(data=res[0].encode("utf8"), digest_bits=256) - hash_str = "0x" + hash_obj.hexdigest()[:13] - res.append(hash_str) options = self.get_options(id_) - if obj: # TODO: temporary solution for backwards compatibility - return SchainStructure(*res, options=options) - else: - res += asdict(options).values() - return res - - @format_fields(FIELDS) - def get_by_name(self, name, obj=False): + return SchainStructure(**asdict(res), chainId=self.name_to_id(res.name), options=options) + + def get_by_name(self, name: SchainName) -> SchainStructure: id_ = self.name_to_id(name) - return self.get(id_, obj=obj) + return self.get(id_) - def get_schains_for_owner(self, account): + def get_schains_for_owner(self, account: ChecksumAddress) -> List[SchainStructure]: schains = [] list_size = self.schains_internal.get_schain_list_size(account) @@ -99,40 +81,48 @@ def get_schains_for_owner(self, account): schains.append(schain) return schains - def get_schains_for_node(self, node_id): + def get_schains_for_node(self, node_id: NodeId) -> list[SchainStructureWithStatus]: schains = [] schain_ids = self.schains_internal.get_schain_ids_for_node(node_id) for schain_id in schain_ids: - schain = self.get(schain_id) - schain['active'] = True if self.schain_active(schain) else False + simple_schain = self.get(schain_id) + schain = SchainStructureWithStatus( + **asdict(simple_schain), + active=self.schain_active(simple_schain) + ) schains.append(schain) return schains - def get_active_schains_for_node(self, node_id): + def get_active_schains_for_node(self, node_id: NodeId) -> List[SchainStructureWithStatus]: schains = [] schain_ids = self.schains_internal.get_active_schain_ids_for_node(node_id) for schain_id in schain_ids: - schain = self.get(schain_id) - schain['active'] = True + simple_schain = self.get(schain_id) + schain = SchainStructureWithStatus( + **asdict(simple_schain), + active=True + ) schains.append(schain) return schains - def name_to_id(self, name): + def name_to_id(self, name: SchainName) -> SchainHash: keccak_hash = keccak.new(data=name.encode("utf8"), digest_bits=256) - return '0x' + keccak_hash.hexdigest() + return SchainHash(Web3.to_bytes(hexstr=Web3.to_hex(hexstr=HexStr(keccak_hash.hexdigest())))) - def get_last_rotation_id(self, schain_name): + def get_last_rotation_id(self, schain_name: SchainName) -> int: rotation_data = self.node_rotation.get_rotation(schain_name) - return rotation_data['rotation_id'] + return rotation_data.rotation_counter - def schain_active(self, schain): - if schain['name'] != '' and \ - schain['mainnetOwner'] != '0x0000000000000000000000000000000000000000': + def schain_active(self, schain: SchainStructure) -> bool: + if schain.name != '' and \ + schain.mainnetOwner != '0x0000000000000000000000000000000000000000': return True + return False - def get_schain_price(self, index_of_type, lifetime): - return self.contract.functions.getSchainPrice(index_of_type, - lifetime).call() + def get_schain_price(self, index_of_type: int, lifetime: int) -> Wei: + return Wei( + self.contract.functions.getSchainPrice(index_of_type, lifetime).call() + ) @transaction_method def add_schain_by_foundation( @@ -140,11 +130,11 @@ def add_schain_by_foundation( lifetime: int, type_of_nodes: int, nonce: int, - name: str, - options: SchainOptions = None, - schain_owner=None, - schain_originator=None - ) -> TxRes: + name: SchainName, + options: SchainOptions | None = None, + schain_owner: ChecksumAddress | None = None, + schain_originator: ChecksumAddress | None = None + ) -> ContractFunction: if schain_owner is None: schain_owner = self.skale.wallet.address if schain_originator is None: @@ -163,23 +153,23 @@ def add_schain_by_foundation( ) @transaction_method - def grant_role(self, role: bytes, owner: str) -> TxRes: + def grant_role(self, role: bytes, owner: ChecksumAddress) -> ContractFunction: return self.contract.functions.grantRole(role, owner) - def schain_creator_role(self): - return self.contract.functions.SCHAIN_CREATOR_ROLE().call() + def schain_creator_role(self) -> bytes: + return bytes(self.contract.functions.SCHAIN_CREATOR_ROLE().call()) - def __raw_get_options(self, schain_id: str) -> list: - return self.contract.functions.getOptions(schain_id).call() + def __raw_get_options(self, schain_id: SchainHash) -> List[Any]: + return list(self.contract.functions.getOptions(schain_id).call()) - def get_options(self, schain_id: str) -> SchainOptions: + def get_options(self, schain_id: SchainHash) -> SchainOptions: return parse_schain_options( raw_options=self.__raw_get_options(schain_id) ) - def get_options_by_name(self, name: str) -> SchainOptions: + def get_options_by_name(self, name: SchainName) -> SchainOptions: id_ = self.name_to_id(name) return self.get_options(id_) - def restart_schain_creation(self, name: str) -> TxRes: + def restart_schain_creation(self, name: SchainName) -> ContractFunction: return self.contract.functions.restartSchainCreation(name) diff --git a/skale/contracts/manager/schains_internal.py b/skale/contracts/manager/schains_internal.py index e91d47ff..d3e4d1e2 100644 --- a/skale/contracts/manager/schains_internal.py +++ b/skale/contracts/manager/schains_internal.py @@ -18,82 +18,109 @@ # along with SKALE.py. If not, see . """ SchainsInternal.sol functions """ +from __future__ import annotations import functools -from skale.contracts.base_contract import BaseContract, transaction_method -from skale.transactions.result import TxRes +from typing import TYPE_CHECKING, List +from eth_typing import ChecksumAddress -class SChainsInternal(BaseContract): +from skale.contracts.base_contract import transaction_method +from skale.contracts.skale_manager_contract import SkaleManagerContract +from skale.types.node import NodeId +from skale.types.schain import Schain, SchainHash, SchainName + +if TYPE_CHECKING: + from web3.contract.contract import ContractFunction + from skale.contracts.manager.schains import SChains + + +class SChainsInternal(SkaleManagerContract): """Wrapper for some of the SchainsInternal.sol functions""" @property @functools.lru_cache() - def schains(self): + def schains(self) -> SChains: return self.skale.schains - def get_raw(self, name): - return self.contract.functions.schains(name).call() + def get_raw(self, name: SchainHash) -> Schain: + return Schain(*self.contract.functions.schains(name).call()) - def get_all_schains_ids(self): - return self.contract.functions.getSchains().call() + def get_all_schains_ids(self) -> List[SchainHash]: + return [ + SchainHash(schain_hash) + for schain_hash + in self.contract.functions.getSchains().call() + ] - def get_schains_number(self): - return self.contract.functions.numberOfSchains().call() + def get_schains_number(self) -> int: + return int(self.contract.functions.numberOfSchains().call()) - def get_schain_list_size(self, account): - return self.contract.functions.getSchainListSize(account).call( - {'from': account}) + def get_schain_list_size(self, account: ChecksumAddress) -> int: + return int(self.contract.functions.getSchainListSize(account).call( + {'from': account})) - def get_schain_id_by_index_for_owner(self, account, index): - return self.contract.functions.schainIndexes(account, index).call() + def get_schain_id_by_index_for_owner(self, account: ChecksumAddress, index: int) -> SchainHash: + return SchainHash(self.contract.functions.schainIndexes(account, index).call()) - def get_node_ids_for_schain(self, name): + def get_node_ids_for_schain(self, name: SchainName) -> List[NodeId]: id_ = self.schains.name_to_id(name) - return self.contract.functions.getNodesInGroup(id_).call() - - def get_schain_ids_for_node(self, node_id): - return self.contract.functions.getSchainHashesForNode(node_id).call() - - def is_schain_exist(self, name): + return [ + NodeId(node) + for node + in self.contract.functions.getNodesInGroup(id_).call() + ] + + def get_schain_ids_for_node(self, node_id: NodeId) -> List[SchainHash]: + return [ + SchainHash(schain) + for schain + in self.contract.functions.getSchainHashesForNode(node_id).call() + ] + + def is_schain_exist(self, name: SchainName) -> bool: id_ = self.schains.name_to_id(name) - return self.contract.functions.isSchainExist(id_).call() + return bool(self.contract.functions.isSchainExist(id_).call()) - def get_active_schain_ids_for_node(self, node_id): - return self.contract.functions.getActiveSchains(node_id).call() + def get_active_schain_ids_for_node(self, node_id: NodeId) -> List[SchainHash]: + return [ + SchainHash(schain) + for schain + in self.contract.functions.getActiveSchains(node_id).call() + ] def number_of_schain_types(self) -> int: - return self.contract.functions.numberOfSchainTypes().call() + return int(self.contract.functions.numberOfSchainTypes().call()) @transaction_method def add_schain_type( self, part_of_node: int, number_of_nodes: int - ) -> TxRes: + ) -> ContractFunction: return self.contract.functions.addSchainType( part_of_node, number_of_nodes) def current_generation(self) -> int: - return self.contract.functions.currentGeneration().call() + return int(self.contract.functions.currentGeneration().call()) @transaction_method - def grant_role(self, role: bytes, address: str) -> TxRes: + def grant_role(self, role: bytes, address: ChecksumAddress) -> ContractFunction: return self.contract.functions.grantRole(role, address) - def has_role(self, role: bytes, address: str) -> bool: - return self.contract.functions.hasRole(role, address).call() + def has_role(self, role: bytes, address: ChecksumAddress) -> bool: + return bool(self.contract.functions.hasRole(role, address).call()) def schain_type_manager_role(self) -> bytes: - return self.contract.functions.SCHAIN_TYPE_MANAGER_ROLE().call() + return bytes(self.contract.functions.SCHAIN_TYPE_MANAGER_ROLE().call()) - def debugger_role(self): - return self.contract.functions.DEBUGGER_ROLE().call() + def debugger_role(self) -> bytes: + return bytes(self.contract.functions.DEBUGGER_ROLE().call()) - def generation_manager_role(self): - return self.contract.functions.GENERATION_MANAGER_ROLE().call() + def generation_manager_role(self) -> bytes: + return bytes(self.contract.functions.GENERATION_MANAGER_ROLE().call()) @transaction_method - def new_generation(self) -> TxRes: + def new_generation(self) -> ContractFunction: return self.contract.functions.newGeneration() - def check_exception(self, schain_name: str, node_id: int) -> bool: + def check_exception(self, schain_name: SchainName, node_id: NodeId) -> bool: id_ = self.schains.name_to_id(schain_name) - return self.contract.functions.checkException(id_, node_id).call() + return bool(self.contract.functions.checkException(id_, node_id).call()) diff --git a/skale/contracts/manager/sync_manager.py b/skale/contracts/manager/sync_manager.py index 1b2c6f1f..16996c60 100644 --- a/skale/contracts/manager/sync_manager.py +++ b/skale/contracts/manager/sync_manager.py @@ -21,8 +21,10 @@ from collections import namedtuple from typing import List +from web3.contract.contract import ContractFunction -from skale.contracts.base_contract import BaseContract, transaction_method +from skale.contracts.base_contract import transaction_method +from skale.contracts.skale_manager_contract import SkaleManagerContract from skale.transactions.result import TxRes from skale.utils.helper import ip_from_bytes, ip_to_bytes @@ -36,11 +38,11 @@ def from_packed(cls, packed_ips: List[bytes]) -> IpRange: ) -class SyncManager(BaseContract): +class SyncManager(SkaleManagerContract): """Wrapper for SyncManager.sol functions""" @transaction_method - def add_ip_range(self, name, start_ip: str, end_ip: str) -> TxRes: + def add_ip_range(self, name: str, start_ip: str, end_ip: str) -> ContractFunction: return self.contract.functions.addIPRange( name, ip_to_bytes(start_ip), @@ -48,11 +50,11 @@ def add_ip_range(self, name, start_ip: str, end_ip: str) -> TxRes: ) @transaction_method - def remove_ip_range(self, name: str) -> TxRes: + def remove_ip_range(self, name: str) -> ContractFunction: return self.contract.functions.removeIPRange(name) def get_ip_ranges_number(self) -> int: - return self.contract.functions.getIPRangesNumber().call() + return int(self.contract.functions.getIPRangesNumber().call()) def get_ip_range_by_index(self, index: int) -> IpRange: packed = self.contract.functions.getIPRangeByIndex(index).call() @@ -66,8 +68,8 @@ def grant_sync_manager_role(self, address: str) -> TxRes: return self.grant_role(self.sync_manager_role(), address) def sync_manager_role(self) -> bytes: - return self.contract.functions.SYNC_MANAGER_ROLE().call() + return bytes(self.contract.functions.SYNC_MANAGER_ROLE().call()) @transaction_method - def grant_role(self, role: bytes, owner: str) -> TxRes: + def grant_role(self, role: bytes, owner: str) -> ContractFunction: return self.contract.functions.grantRole(role, owner) diff --git a/skale/contracts/manager/test/time_helpers_with_debug.py b/skale/contracts/manager/test/time_helpers_with_debug.py index 6ffb701d..cc0cb795 100644 --- a/skale/contracts/manager/test/time_helpers_with_debug.py +++ b/skale/contracts/manager/test/time_helpers_with_debug.py @@ -17,15 +17,17 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . -from skale.contracts.base_contract import BaseContract, transaction_method -from skale.transactions.result import TxRes +from web3.contract.contract import ContractFunction +from skale.contracts.base_contract import transaction_method +from skale.contracts.skale_manager_contract import SkaleManagerContract -class TimeHelpersWithDebug(BaseContract): + +class TimeHelpersWithDebug(SkaleManagerContract): """Wrapper for TimeHelpersWithDebug.sol functions (internal usage only)""" @transaction_method - def skip_time(self, sec: int) -> TxRes: + def skip_time(self, sec: int) -> ContractFunction: """Skip time on contracts :param sec: Time to skip in seconds @@ -41,4 +43,4 @@ def get_current_month(self) -> int: :returns: Month index :rtype: int """ - return self.contract.functions.getCurrentMonth().call() + return int(self.contract.functions.getCurrentMonth().call()) diff --git a/skale/contracts/manager/token.py b/skale/contracts/manager/token.py index a58a891d..648d8850 100644 --- a/skale/contracts/manager/token.py +++ b/skale/contracts/manager/token.py @@ -18,24 +18,35 @@ # along with SKALE.py. If not, see . """ SKALE token operations """ -from skale.contracts.base_contract import BaseContract, transaction_method +from eth_typing import ChecksumAddress +from web3.contract.contract import ContractFunction +from web3.types import Wei +from skale.contracts.base_contract import transaction_method +from skale.contracts.skale_manager_contract import SkaleManagerContract -class Token(BaseContract): + +class Token(SkaleManagerContract): @transaction_method - def transfer(self, address, value): + def transfer(self, address: ChecksumAddress, value: Wei) -> ContractFunction: return self.contract.functions.send(address, value, b'') - def get_balance(self, address): - return self.contract.functions.balanceOf(address).call() + def get_balance(self, address: ChecksumAddress) -> Wei: + return Wei(self.contract.functions.balanceOf(address).call()) @transaction_method - def add_authorized(self, address, wallet): # pragma: no cover + def add_authorized(self, address: ChecksumAddress) -> ContractFunction: # pragma: no cover return self.contract.functions.addAuthorized(address) - def get_and_update_slashed_amount(self, address: str) -> int: - return self.contract.functions.getAndUpdateSlashedAmount(address).call() + def get_and_update_slashed_amount(self, address: ChecksumAddress) -> Wei: + return Wei(self.contract.functions.getAndUpdateSlashedAmount(address).call()) @transaction_method - def mint(self, address, amount, user_data=b'', operator_data=b''): + def mint( + self, + address: ChecksumAddress, + amount: Wei, + user_data: bytes = b'', + operator_data: bytes = b'' + ) -> ContractFunction: return self.contract.functions.mint(address, amount, user_data, operator_data) diff --git a/skale/contracts/manager/wallets.py b/skale/contracts/manager/wallets.py index 01a579e0..34c3d853 100644 --- a/skale/contracts/manager/wallets.py +++ b/skale/contracts/manager/wallets.py @@ -17,24 +17,26 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . -from skale.contracts.base_contract import BaseContract, transaction_method -from skale.transactions.result import TxRes +from web3.contract.contract import ContractFunction +from skale.contracts.base_contract import transaction_method +from skale.contracts.skale_manager_contract import SkaleManagerContract -class Wallets(BaseContract): + +class Wallets(SkaleManagerContract): def get_validator_balance(self, validator_id: int) -> int: """Returns SRW balance by validator id (in wei). :returns: SRW balance (wei) :rtype: int """ - return self.contract.functions.getValidatorBalance(validator_id).call() + return int(self.contract.functions.getValidatorBalance(validator_id).call()) @transaction_method - def recharge_validator_wallet(self, validator_id: int) -> TxRes: + def recharge_validator_wallet(self, validator_id: int) -> ContractFunction: """Pass value kwarg (in wei) to the function when calling it""" return self.contract.functions.rechargeValidatorWallet(validator_id) @transaction_method - def withdraw_funds_from_validator_wallet(self, amount: int) -> TxRes: + def withdraw_funds_from_validator_wallet(self, amount: int) -> ContractFunction: return self.contract.functions.withdrawFundsFromValidatorWallet(amount) diff --git a/skale/contracts/skale_manager_contract.py b/skale/contracts/skale_manager_contract.py new file mode 100644 index 00000000..3c32734f --- /dev/null +++ b/skale/contracts/skale_manager_contract.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +# +# This file is part of SKALE.py +# +# Copyright (C) 2019-Present SKALE Labs +# +# SKALE.py is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# SKALE.py is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with SKALE.py. If not, see . +from skale.contracts.base_contract import BaseContract +from skale.skale_manager import SkaleManager + + +class SkaleManagerContract(BaseContract[SkaleManager]): + pass diff --git a/skale/dataclasses/node_info.py b/skale/dataclasses/node_info.py index 36d8ddbc..69ee2bc2 100644 --- a/skale/dataclasses/node_info.py +++ b/skale/dataclasses/node_info.py @@ -19,25 +19,26 @@ from dataclasses import dataclass from skale.dataclasses.skaled_ports import SkaledPorts +from skale.types.node import NodeId, Port @dataclass class NodeInfo(): """Dataclass that represents base info about the node""" - node_id: int + node_id: NodeId name: str - base_port: int + base_port: Port - def calc_ports(self): + def calc_ports(self) -> dict[str, Port]: return { - 'httpRpcPort': self.base_port + SkaledPorts.HTTP_JSON.value, - 'httpsRpcPort': self.base_port + SkaledPorts.HTTPS_JSON.value, - 'wsRpcPort': self.base_port + SkaledPorts.WS_JSON.value, - 'wssRpcPort': self.base_port + SkaledPorts.WSS_JSON.value, - 'infoHttpRpcPort': self.base_port + SkaledPorts.INFO_HTTP_JSON.value + 'httpRpcPort': Port(self.base_port + SkaledPorts.HTTP_JSON.value), + 'httpsRpcPort': Port(self.base_port + SkaledPorts.HTTPS_JSON.value), + 'wsRpcPort': Port(self.base_port + SkaledPorts.WS_JSON.value), + 'wssRpcPort': Port(self.base_port + SkaledPorts.WSS_JSON.value), + 'infoHttpRpcPort': Port(self.base_port + SkaledPorts.INFO_HTTP_JSON.value) } - def to_dict(self): + def to_dict(self) -> dict[str, NodeId | str | Port]: return { 'nodeID': self.node_id, 'nodeName': self.name, diff --git a/skale/dataclasses/schain_options.py b/skale/dataclasses/schain_options.py index 32d74357..d85057f8 100644 --- a/skale/dataclasses/schain_options.py +++ b/skale/dataclasses/schain_options.py @@ -16,8 +16,12 @@ # # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . - +from __future__ import annotations from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from skale.types.schain import SchainOption @dataclass @@ -25,14 +29,14 @@ class SchainOptions: multitransaction_mode: bool threshold_encryption: bool - def to_tuples(self) -> list: + def to_tuples(self) -> list[SchainOption]: return [ ('multitr', bool_to_bytes(self.multitransaction_mode)), ('encrypt', bool_to_bytes(self.threshold_encryption)) ] -def parse_schain_options(raw_options: list) -> SchainOptions: +def parse_schain_options(raw_options: list[SchainOption]) -> SchainOptions: """ Parses raw sChain options from smart contracts (list of tuples). Returns default values if nothing is set on contracts. diff --git a/skale/schain_config/generator.py b/skale/schain_config/generator.py index 4f6f07d8..0a6e5c20 100644 --- a/skale/schain_config/generator.py +++ b/skale/schain_config/generator.py @@ -18,19 +18,27 @@ # along with SKALE.py. If not, see . -def get_nodes_for_schain(skale, name): +from skale.skale_manager import SkaleManager +from skale.types.node import NodeWithId, NodeWithSchains +from skale.types.schain import SchainName + + +def get_nodes_for_schain(skale: SkaleManager, name: SchainName) -> list[NodeWithId]: nodes = [] ids = skale.schains_internal.get_node_ids_for_schain(name) for id_ in ids: node = skale.nodes.get(id_) - node['id'] = id_ - nodes.append(node) + nodes.append(NodeWithId(id=id_, **node)) return nodes -def get_schain_nodes_with_schains(skale, schain_name) -> list: +def get_schain_nodes_with_schains( + skale: SkaleManager, + schain_name: SchainName +) -> list[NodeWithSchains]: """Returns list of nodes for schain with schains for all nodes""" nodes = get_nodes_for_schain(skale, schain_name) - for node in nodes: - node['schains'] = skale.schains.get_schains_for_node(node['id']) - return nodes + return [ + NodeWithSchains(schains=skale.schains.get_schains_for_node(node['id']), **node) + for node in nodes + ] diff --git a/skale/schain_config/ports_allocation.py b/skale/schain_config/ports_allocation.py index 90fa71e2..6c6526f0 100644 --- a/skale/schain_config/ports_allocation.py +++ b/skale/schain_config/ports_allocation.py @@ -17,21 +17,27 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . +from skale.types.node import Port +from skale.types.schain import SchainName, SchainStructure from skale.utils.exceptions import SChainNotFoundException from skale.schain_config import PORTS_PER_SCHAIN -def calc_schain_base_port(node_base_port, schain_index): - return node_base_port + schain_index * PORTS_PER_SCHAIN +def calc_schain_base_port(node_base_port: Port, schain_index: int) -> Port: + return Port(node_base_port + schain_index * PORTS_PER_SCHAIN) -def get_schain_index_in_node(schain_name, node_schains): +def get_schain_index_in_node(schain_name: SchainName, node_schains: list[SchainStructure]) -> int: for index, schain in enumerate(node_schains): - if schain_name == schain['name']: + if schain_name == schain.name: return index raise SChainNotFoundException(f'sChain {schain_name} is not found in the list: {node_schains}') -def get_schain_base_port_on_node(schains_on_node, schain_name, node_base_port): +def get_schain_base_port_on_node( + schains_on_node: list[SchainStructure], + schain_name: SchainName, + node_base_port: Port +) -> Port: schain_index = get_schain_index_in_node(schain_name, schains_on_node) return calc_schain_base_port(node_base_port, schain_index) diff --git a/skale/schain_config/rotation_history.py b/skale/schain_config/rotation_history.py index 7ef488fe..93213b92 100644 --- a/skale/schain_config/rotation_history.py +++ b/skale/schain_config/rotation_history.py @@ -17,35 +17,44 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . +from __future__ import annotations import logging -from collections import namedtuple +from typing import TYPE_CHECKING, Dict, List, TypedDict +from skale.types.rotation import RotationNodeData -from skale import Skale -from skale.contracts.manager.node_rotation import Rotation +if TYPE_CHECKING: + from skale.skale_manager import SkaleManager + from skale.types.dkg import G2Point + from skale.types.node import NodeId + from skale.types.rotation import BlsPublicKey, NodesGroup, Rotation + from skale.types.schain import SchainName logger = logging.getLogger(__name__) -RotationNodeData = namedtuple('RotationNodeData', ['index', 'node_id', 'public_key']) + +class PreviousNodeData(TypedDict): + finish_ts: int + previous_node_id: NodeId def get_previous_schain_groups( - skale: Skale, - schain_name: str, - leaving_node_id=None, -) -> dict: + skale: SkaleManager, + schain_name: SchainName, + leaving_node_id: NodeId | None = None, +) -> Dict[int, NodesGroup]: """ Returns all previous node groups with public keys and finish timestamps. In case of no rotations returns the current state. """ logger.info(f'Collecting rotation history for {schain_name}...') - node_groups = {} + node_groups: dict[int, NodesGroup] = {} - group_id = skale.schains.name_to_group_id(schain_name) + group_id = skale.schains.name_to_id(schain_name) previous_public_keys = skale.key_storage.get_all_previous_public_keys(group_id) current_public_key = skale.key_storage.get_common_public_key(group_id) - rotation = skale.node_rotation.get_rotation_obj(schain_name) + rotation = skale.node_rotation.get_rotation(schain_name) logger.info(f'Rotation data for {schain_name}: {rotation}') @@ -66,12 +75,12 @@ def get_previous_schain_groups( def _add_current_schain_state( - skale: Skale, - node_groups: dict, + skale: SkaleManager, + node_groups: dict[int, NodesGroup], rotation: Rotation, - schain_name: str, - current_public_key: list -) -> dict: + schain_name: SchainName, + current_public_key: G2Point +) -> None: """ Internal function, composes the initial info about the current sChain state and adds it to the node_groups dictionary @@ -91,18 +100,19 @@ def _add_current_schain_state( def _add_previous_schain_rotations_state( - skale: Skale, - node_groups: dict, + skale: SkaleManager, + node_groups: dict[int, NodesGroup], rotation: Rotation, - schain_name: str, - previous_public_keys: list, - leaving_node_id=None -) -> dict: + schain_name: SchainName, + previous_public_keys: list[G2Point], + leaving_node_id: NodeId | None = None +) -> None: """ Internal function, handles rotations from (rotation_counter - 2) to 0 and adds them to the node_groups dictionary """ - previous_nodes = {} + + previous_nodes: Dict[NodeId, PreviousNodeData] = {} for rotation_id in range(rotation.rotation_counter - 1, -1, -1): nodes = node_groups[rotation_id + 1]['nodes'].copy() @@ -112,7 +122,7 @@ def _add_previous_schain_rotations_state( if previous_node is not None: finish_ts = skale.node_rotation.get_schain_finish_ts(previous_node, schain_name) previous_nodes[node_id] = { - 'finish_ts': finish_ts, + 'finish_ts': finish_ts or 0, 'previous_node_id': previous_node } @@ -158,7 +168,7 @@ def _add_previous_schain_rotations_state( break -def _pop_previous_bls_public_key(previous_public_keys): +def _pop_previous_bls_public_key(previous_public_keys: List[G2Point]) -> BlsPublicKey | None: """ Returns BLS public key for the group and removes it from the list, returns None if node with provided node_id was kicked out of the chain because of failed DKG. @@ -169,7 +179,7 @@ def _pop_previous_bls_public_key(previous_public_keys): return bls_keys -def _compose_bls_public_key_info(bls_public_key: str) -> dict: +def _compose_bls_public_key_info(bls_public_key: G2Point) -> BlsPublicKey | None: if bls_public_key: return { 'blsPublicKey0': str(bls_public_key[0][0]), @@ -177,12 +187,17 @@ def _compose_bls_public_key_info(bls_public_key: str) -> dict: 'blsPublicKey2': str(bls_public_key[1][0]), 'blsPublicKey3': str(bls_public_key[1][1]) } + return None -def get_new_nodes_list(skale: Skale, name: str, node_groups) -> list: +def get_new_nodes_list( + skale: SkaleManager, + name: SchainName, + node_groups: Dict[int, NodesGroup] +) -> list[NodeId]: """Returns list of new nodes in for the latest rotation""" logger.info(f'Getting new nodes list for chain {name}') - rotation = skale.node_rotation.get_rotation_obj(name) + rotation = skale.node_rotation.get_rotation(name) current_group_ids = node_groups[rotation.rotation_counter]['nodes'].keys() new_nodes = [] for index in node_groups: diff --git a/skale/skale_allocator.py b/skale/skale_allocator.py index db0f00ce..7fad86e4 100644 --- a/skale/skale_allocator.py +++ b/skale/skale_allocator.py @@ -19,35 +19,24 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING -from web3.constants import ADDRESS_ZERO +from typing import TYPE_CHECKING, List, cast +from web3.constants import CHECKSUM_ADDRESSS_ZERO from skale.skale_base import SkaleBase -import skale.contracts.allocator as contracts -from skale.contracts.contract_manager import ContractManager from skale.utils.contract_info import ContractInfo from skale.utils.contract_types import ContractTypes from skale.utils.helper import get_contracts_info if TYPE_CHECKING: from eth_typing import ChecksumAddress + from skale.contracts.allocator.allocator import Allocator logger = logging.getLogger(__name__) -CONTRACTS_INFO = [ - ContractInfo('contract_manager', 'ContractManager', - ContractManager, ContractTypes.API, False), - ContractInfo('escrow', 'Escrow', contracts.Escrow, - ContractTypes.API, True), - ContractInfo('allocator', 'Allocator', contracts.Allocator, - ContractTypes.API, True) -] - - -def spawn_skale_allocator_lib(skale): - return SkaleAllocator(skale._endpoint, skale._abi_filepath, skale.wallet) +def spawn_skale_allocator_lib(skale: SkaleAllocator) -> SkaleAllocator: + return SkaleAllocator(skale._endpoint, skale.instance.address, skale.wallet) class SkaleAllocator(SkaleBase): @@ -56,10 +45,23 @@ class SkaleAllocator(SkaleBase): def project_name(self) -> str: return 'skale-allocator' - def get_contract_address(self, name) -> ChecksumAddress: + @property + def allocator(self) -> Allocator: + return cast('Allocator', super()._get_contract('allocator')) + + def contracts_info(self) -> List[ContractInfo[SkaleAllocator]]: + import skale.contracts.allocator as contracts + return [ + ContractInfo('escrow', 'Escrow', contracts.Escrow, + ContractTypes.API, True), + ContractInfo('allocator', 'Allocator', contracts.Allocator, + ContractTypes.API, True) + ] + + def get_contract_address(self, name: str) -> ChecksumAddress: if name == 'Escrow': - return ADDRESS_ZERO + return CHECKSUM_ADDRESSS_ZERO return super().get_contract_address(name) - def set_contracts_info(self): - self._SkaleBase__contracts_info = get_contracts_info(CONTRACTS_INFO) + def set_contracts_info(self) -> None: + self._SkaleBase__contracts_info = get_contracts_info(self.contracts_info()) diff --git a/skale/skale_base.py b/skale/skale_base.py index 92adbfff..8c3e19e1 100644 --- a/skale/skale_base.py +++ b/skale/skale_base.py @@ -21,18 +21,18 @@ import abc import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, Self, Type from skale_contracts import skale_contracts -from skale.wallets import BaseWallet from skale.utils.exceptions import InvalidWalletError, EmptyWalletError from skale.utils.web3_utils import default_gas_price, init_web3 - -from skale.contracts.contract_manager import ContractManager +from skale.wallets import BaseWallet if TYPE_CHECKING: - from eth_typing import Address, ChecksumAddress + from eth_typing import ChecksumAddress + from skale.contracts.base_contract import BaseContract + from skale.utils.contract_info import ContractInfo logger = logging.getLogger(__name__) @@ -45,9 +45,14 @@ class EmptyPrivateKey(Exception): class SkaleBase: __metaclass__ = abc.ABCMeta - def __init__(self, endpoint, alias_or_address: str, - wallet=None, state_path=None, - ts_diff=None, provider_timeout=30): + def __init__( + self, + endpoint: str, + alias_or_address: str, + wallet: BaseWallet | None = None, + state_path: str | None = None, + ts_diff: int | None = None, + provider_timeout: int = 30): logger.info('Initializing skale.py, endpoint: %s, wallet: %s', endpoint, type(wallet).__name__) self._endpoint = endpoint @@ -58,8 +63,8 @@ def __init__(self, endpoint, alias_or_address: str, self.network = skale_contracts.get_network_by_provider(self.web3.provider) self.project = self.network.get_project(self.project_name) self.instance = self.project.get_instance(alias_or_address) - self.__contracts = {} - self.__contracts_info = {} + self.__contracts: Dict[str, BaseContract[Self]] = {} + self.__contracts_info: Dict[str, ContractInfo[Self]] = {} self.set_contracts_info() if wallet: self.wallet = wallet @@ -70,34 +75,28 @@ def project_name(self) -> str: """Name of smart contracts project""" @property - def gas_price(self): + def gas_price(self) -> int: return default_gas_price(self.web3) @property - def wallet(self): + def wallet(self) -> BaseWallet: if not self._wallet: raise EmptyWalletError('No wallet provided') return self._wallet @wallet.setter - def wallet(self, wallet): + def wallet(self, wallet: BaseWallet) -> None: if issubclass(type(wallet), BaseWallet): self._wallet = wallet else: raise InvalidWalletError(f'Wrong wallet class: {type(wallet).__name__}. \ Must be one of the BaseWallet subclasses') - def __is_debug_contracts(self, abi): - return abi.get('time_helpers_with_debug_address', None) - @abc.abstractmethod - def set_contracts_info(self): - return - - def init_contract_manager(self): - self.add_lib_contract('contract_manager', ContractManager, 'ContractManager') + def set_contracts_info(self) -> None: + pass - def __init_contract_from_info(self, contract_info): + def __init_contract_from_info(self, contract_info: ContractInfo[Self]) -> None: if contract_info.upgradeable: self.init_upgradeable_contract(contract_info) else: @@ -107,7 +106,7 @@ def __init_contract_from_info(self, contract_info): contract_info.contract_name ) - def init_upgradeable_contract(self, contract_info): + def init_upgradeable_contract(self, contract_info: ContractInfo[Self]) -> None: address = self.get_contract_address(contract_info.contract_name) self.add_lib_contract( contract_info.name, @@ -116,31 +115,39 @@ def init_upgradeable_contract(self, contract_info): address ) - def add_lib_contract(self, name: str, contract_class, - contract_name: str, contract_address: Address = None): + def add_lib_contract( + self, + name: str, + contract_class: Type[BaseContract[Self]], + contract_name: str, + contract_address: ChecksumAddress | None = None + ) -> None: address = contract_address or self.instance.get_contract_address(contract_name) logger.debug('Fetching abi for %s, address %s', name, address) contract_abi = self.instance.abi[contract_name] self.add_contract(name, contract_class( self, name, address, contract_abi)) - def add_contract(self, name, contract): + def add_contract(self, name: str, contract: BaseContract[Self]) -> None: self.__contracts[name] = contract - def get_contract_address(self, name) -> ChecksumAddress: + def get_contract_address(self, name: str) -> ChecksumAddress: return self.web3.to_checksum_address( self.instance.get_contract_address(name) ) - def __get_contract_by_name(self, name): - return self.__contracts[name] - - def __getattr__(self, name): + def _get_contract(self, name: str) -> BaseContract[Self]: if name not in self.__contracts: if not self.__contracts_info.get(name): logger.warning("%s method/contract wasn't found", name) - return None + raise ValueError(name, ' is an unknown contract') logger.debug("Contract %s wasn't inited, creating now", name) contract_info = self.__contracts_info[name] self.__init_contract_from_info(contract_info) return self.__get_contract_by_name(name) + + def __get_contract_by_name(self, name: str) -> BaseContract[Self]: + return self.__contracts[name] + + def __getattr__(self, name: str) -> BaseContract[Self]: + return self._get_contract(name) diff --git a/skale/skale_ima.py b/skale/skale_ima.py index a1a07140..ae9f75e4 100644 --- a/skale/skale_ima.py +++ b/skale/skale_ima.py @@ -17,10 +17,11 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . +from __future__ import annotations import logging +from typing import List from skale.skale_base import SkaleBase -import skale.contracts.ima as contracts from skale.utils.contract_info import ContractInfo from skale.utils.contract_types import ContractTypes from skale.utils.helper import get_contracts_info @@ -29,17 +30,22 @@ logger = logging.getLogger(__name__) -CONTRACTS_INFO = [ - ContractInfo('linker', 'Linker', - contracts.Linker, ContractTypes.API, False) -] +class SkaleIma(SkaleBase): + @property + def project_name(self) -> str: + return 'mainnet-ima' + def contracts_info(self) -> List[ContractInfo[SkaleIma]]: + import skale.contracts.ima as contracts + return [ + ContractInfo('linker', 'Linker', + contracts.Linker, ContractTypes.API, False) + ] -def spawn_skale_ima_lib(skale_ima): - """ Clone skale ima object with the same wallet """ - return SkaleIma(skale_ima._endpoint, skale_ima._abi_filepath, skale_ima.wallet) + def set_contracts_info(self) -> None: + self._SkaleBase__contracts_info = get_contracts_info(self.contracts_info()) -class SkaleIma(SkaleBase): - def set_contracts_info(self): - self._SkaleBase__contracts_info = get_contracts_info(CONTRACTS_INFO) +def spawn_skale_ima_lib(skale_ima: SkaleIma) -> SkaleIma: + """ Clone skale ima object with the same wallet """ + return SkaleIma(skale_ima._endpoint, skale_ima.instance.address, skale_ima.wallet) diff --git a/skale/skale_manager.py b/skale/skale_manager.py index 592b6848..ce12c927 100644 --- a/skale/skale_manager.py +++ b/skale/skale_manager.py @@ -17,67 +17,20 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . +from __future__ import annotations import logging +from typing import List, TYPE_CHECKING, cast from skale.skale_base import SkaleBase -import skale.contracts.manager as contracts -from skale.contracts.contract_manager import ContractManager from skale.utils.contract_info import ContractInfo from skale.utils.contract_types import ContractTypes from skale.utils.helper import get_contracts_info - -logger = logging.getLogger(__name__) +if TYPE_CHECKING: + import skale.contracts.manager as contracts -CONTRACTS_INFO = [ - ContractInfo('contract_manager', 'ContractManager', - ContractManager, ContractTypes.API, False), - ContractInfo('token', 'SkaleToken', contracts.Token, ContractTypes.API, - False), - ContractInfo('manager', 'SkaleManager', contracts.Manager, - ContractTypes.API, True), - ContractInfo('constants_holder', 'ConstantsHolder', contracts.ConstantsHolder, - ContractTypes.INTERNAL, True), - ContractInfo('nodes', 'Nodes', contracts.Nodes, - ContractTypes.API, True), - ContractInfo('node_rotation', 'NodeRotation', contracts.NodeRotation, - ContractTypes.API, True), - ContractInfo('schains', 'Schains', contracts.SChains, - ContractTypes.API, True), - ContractInfo('schains_internal', 'SchainsInternal', contracts.SChainsInternal, - ContractTypes.API, True), - ContractInfo('dkg', 'SkaleDKG', contracts.DKG, ContractTypes.API, True), - ContractInfo('key_storage', 'KeyStorage', - contracts.KeyStorage, ContractTypes.API, True), - ContractInfo('delegation_controller', 'DelegationController', contracts.DelegationController, - ContractTypes.API, False), - ContractInfo('delegation_period_manager', 'DelegationPeriodManager', - contracts.DelegationPeriodManager, ContractTypes.API, False), - ContractInfo('validator_service', 'ValidatorService', contracts.ValidatorService, - ContractTypes.API, False), - ContractInfo('token_state', 'TokenState', contracts.TokenState, - ContractTypes.API, False), - ContractInfo('distributor', 'Distributor', contracts.Distributor, - ContractTypes.API, False), - ContractInfo('slashing_table', 'SlashingTable', contracts.SlashingTable, - ContractTypes.API, False), - ContractInfo('wallets', 'Wallets', contracts.Wallets, - ContractTypes.API, True), - ContractInfo('bounty_v2', 'BountyV2', contracts.BountyV2, - ContractTypes.API, True), - ContractInfo('punisher', 'Punisher', contracts.Punisher, - ContractTypes.API, True), - ContractInfo('sync_manager', 'SyncManager', contracts.SyncManager, - ContractTypes.API, False), - ContractInfo('time_helpers_with_debug', 'TimeHelpersWithDebug', contracts.TimeHelpersWithDebug, - ContractTypes.API, False) -] - - -def spawn_skale_manager_lib(skale): - """ Clone skale manager object with the same wallet """ - return SkaleManager(skale._endpoint, skale._abi_filepath, skale.wallet) +logger = logging.getLogger(__name__) class SkaleManager(SkaleBase): @@ -86,6 +39,149 @@ class SkaleManager(SkaleBase): def project_name(self) -> str: return 'skale-manager' - def set_contracts_info(self): + @staticmethod + def contracts_info() -> List[ContractInfo[SkaleManager]]: + import skale.contracts.manager as contracts + return [ + ContractInfo('contract_manager', 'ContractManager', + contracts.ContractManager, ContractTypes.API, False), + ContractInfo('token', 'SkaleToken', contracts.Token, ContractTypes.API, + False), + ContractInfo('manager', 'SkaleManager', contracts.Manager, + ContractTypes.API, True), + ContractInfo('constants_holder', 'ConstantsHolder', contracts.ConstantsHolder, + ContractTypes.INTERNAL, True), + ContractInfo('nodes', 'Nodes', contracts.Nodes, + ContractTypes.API, True), + ContractInfo('node_rotation', 'NodeRotation', contracts.NodeRotation, + ContractTypes.API, True), + ContractInfo('schains', 'Schains', contracts.SChains, + ContractTypes.API, True), + ContractInfo('schains_internal', 'SchainsInternal', contracts.SChainsInternal, + ContractTypes.API, True), + ContractInfo('dkg', 'SkaleDKG', contracts.DKG, ContractTypes.API, True), + ContractInfo('key_storage', 'KeyStorage', + contracts.KeyStorage, ContractTypes.API, True), + ContractInfo('delegation_controller', 'DelegationController', + contracts.DelegationController, ContractTypes.API, False), + ContractInfo('delegation_period_manager', 'DelegationPeriodManager', + contracts.DelegationPeriodManager, ContractTypes.API, False), + ContractInfo('validator_service', 'ValidatorService', + contracts.ValidatorService, ContractTypes.API, False), + ContractInfo('token_state', 'TokenState', contracts.TokenState, + ContractTypes.API, False), + ContractInfo('distributor', 'Distributor', contracts.Distributor, + ContractTypes.API, False), + ContractInfo('slashing_table', 'SlashingTable', contracts.SlashingTable, + ContractTypes.API, False), + ContractInfo('wallets', 'Wallets', contracts.Wallets, + ContractTypes.API, True), + ContractInfo('bounty_v2', 'BountyV2', contracts.BountyV2, + ContractTypes.API, True), + ContractInfo('punisher', 'Punisher', contracts.Punisher, + ContractTypes.API, True), + ContractInfo('sync_manager', 'SyncManager', contracts.SyncManager, + ContractTypes.API, False), + ContractInfo('time_helpers_with_debug', 'TimeHelpersWithDebug', + contracts.TimeHelpersWithDebug, ContractTypes.API, False) + ] + + @property + def bounty_v2(self) -> contracts.BountyV2: + return cast('contracts.BountyV2', self._get_contract('bounty_v2')) + + @property + def constants_holder(self) -> contracts.ConstantsHolder: + return cast('contracts.ConstantsHolder', self._get_contract('constants_holder')) + + @property + def contract_manager(self) -> contracts.ContractManager: + return cast('contracts.ContractManager', self._get_contract('contract_manager')) + + @property + def delegation_controller(self) -> contracts.DelegationController: + return cast('contracts.DelegationController', self._get_contract('delegation_controller')) + + @property + def delegation_period_manager(self) -> contracts.DelegationPeriodManager: + return cast( + 'contracts.DelegationPeriodManager', + self._get_contract('delegation_period_manager') + ) + + @property + def distributor(self) -> contracts.Distributor: + return cast('contracts.Distributor', self._get_contract('distributor')) + + @property + def dkg(self) -> contracts.DKG: + return cast('contracts.DKG', self._get_contract('dkg')) + + @property + def key_storage(self) -> contracts.KeyStorage: + return cast('contracts.KeyStorage', self._get_contract('key_storage')) + + @property + def manager(self) -> contracts.Manager: + return cast('contracts.Manager', self._get_contract('manager')) + + @property + def node_rotation(self) -> contracts.NodeRotation: + return cast('contracts.NodeRotation', self._get_contract('node_rotation')) + + @property + def nodes(self) -> contracts.Nodes: + return cast('contracts.Nodes', self._get_contract('nodes')) + + @property + def punisher(self) -> contracts.Punisher: + return cast('contracts.Punisher', self._get_contract('punisher')) + + @property + def schains(self) -> contracts.SChains: + return cast('contracts.SChains', self._get_contract('schains')) + + @property + def schains_internal(self) -> contracts.SChainsInternal: + return cast('contracts.SChainsInternal', self._get_contract('schains_internal')) + + @property + def slashing_table(self) -> contracts.SlashingTable: + return cast('contracts.SlashingTable', self._get_contract('slashing_table')) + + @property + def sync_manager(self) -> contracts.SyncManager: + return cast('contracts.SyncManager', self._get_contract('sync_manager')) + + @property + def time_helpers_with_debug(self) -> contracts.TimeHelpersWithDebug: + return cast('contracts.TimeHelpersWithDebug', self._get_contract('time_helpers_with_debug')) + + @property + def token(self) -> contracts.Token: + return cast('contracts.Token', self._get_contract('token')) + + @property + def token_state(self) -> contracts.TokenState: + return cast('contracts.TokenState', self._get_contract('token_state')) + + @property + def validator_service(self) -> contracts.ValidatorService: + return cast('contracts.ValidatorService', self._get_contract('validator_service')) + + @property + def wallets(self) -> contracts.Wallets: + return cast('contracts.Wallets', self._get_contract('wallets')) + + def init_contract_manager(self) -> None: + from skale.contracts.manager.contract_manager import ContractManager + self.add_lib_contract('contract_manager', ContractManager, 'ContractManager') + + def set_contracts_info(self) -> None: self.init_contract_manager() - self._SkaleBase__contracts_info = get_contracts_info(CONTRACTS_INFO) + self._SkaleBase__contracts_info = get_contracts_info(self.contracts_info()) + + +def spawn_skale_manager_lib(skale: SkaleManager) -> SkaleManager: + """ Clone skale manager object with the same wallet """ + return SkaleManager(skale._endpoint, skale.instance.address, skale.wallet) diff --git a/skale/transactions/result.py b/skale/transactions/result.py index f5c0251f..e29372b1 100644 --- a/skale/transactions/result.py +++ b/skale/transactions/result.py @@ -18,7 +18,10 @@ # along with SKALE.py. If not, see . import enum -from typing import NamedTuple +from typing import Mapping, NamedTuple + +from web3.types import TxReceipt +from eth_typing import HexStr from skale.transactions.exceptions import ( DryRunFailedError, @@ -36,11 +39,16 @@ class TxCallResult(NamedTuple): status: TxStatus error: str message: str - data: dict + data: Mapping[str, str | int] class TxRes: - def __init__(self, tx_call_result=None, tx_hash=None, receipt=None, revert=None): + def __init__( + self, + tx_call_result: TxCallResult | None = None, + tx_hash: HexStr | None = None, + receipt: TxReceipt | None = None + ): self.tx_call_result = tx_call_result self.tx_hash = tx_hash self.receipt = receipt @@ -61,7 +69,10 @@ def __repr__(self) -> str: def raise_for_status(self) -> None: if self.receipt is not None: if self.receipt['status'] == TxStatus.FAILED: - raise TransactionFailedError(self.receipt) + raise TransactionFailedError( + "Tx status is failed", + {key: str(value) for key, value in self.receipt.items()} + ) elif self.tx_call_result is not None and self.tx_call_result.status == TxStatus.FAILED: if self.tx_call_result.error == 'revert': raise DryRunRevertError(self.tx_call_result.message) diff --git a/skale/transactions/tools.py b/skale/transactions/tools.py index f0451933..f6153d8b 100644 --- a/skale/transactions/tools.py +++ b/skale/transactions/tools.py @@ -17,20 +17,27 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . +from __future__ import annotations import logging import time from functools import partial, wraps -from typing import Dict, Optional +from typing import Any, Callable, Optional, TYPE_CHECKING +from eth_typing import ChecksumAddress from web3 import Web3 +from web3.contract.contract import ContractFunction from web3.exceptions import ContractLogicError, Web3Exception from web3._utils.transactions import get_block_gas_limit +from web3.types import Nonce, TxParams, Wei import skale.config as config from skale.transactions.exceptions import TransactionError from skale.transactions.result import TxCallResult, TxRes, TxStatus from skale.utils.web3_utils import get_eth_nonce +if TYPE_CHECKING: + from skale.skale_base import SkaleBase + logger = logging.getLogger(__name__) @@ -38,11 +45,16 @@ DEFAULT_ETH_SEND_GAS_LIMIT = 22000 -def make_dry_run_call(skale, method, gas_limit=None, value=0) -> TxCallResult: - opts = { +def make_dry_run_call( + skale: SkaleBase, + method: ContractFunction, + gas_limit: int | None = None, + value: Wei = Wei(0) +) -> TxCallResult: + opts = TxParams({ 'from': skale.wallet.address, 'value': value - } + }) logger.info( f'Dry run tx: {method.fn_name}, ' f'sender: {skale.wallet.address}, ' @@ -60,17 +72,28 @@ def make_dry_run_call(skale, method, gas_limit=None, value=0) -> TxCallResult: estimated_gas = estimate_gas(skale.web3, method, opts) logger.info(f'Estimated gas for {method.fn_name}: {estimated_gas}') except ContractLogicError as e: - return TxCallResult(status=TxStatus.FAILED, - error='revert', message=e.message, data=e.data) + message = e.message or 'Contract logic error' + error_data = e.data or {} + data = {'data': error_data} if isinstance(error_data, str) else error_data + return TxCallResult( + status=TxStatus.FAILED, + error='revert', + message=message, + data=data + ) except (Web3Exception, ValueError) as e: logger.exception('Dry run for %s failed', method) return TxCallResult(status=TxStatus.FAILED, error='exception', message=str(e), data={}) - return TxCallResult(status=TxStatus.SUCCESS, error='', - message='success', data={'gas': estimated_gas}) + return TxCallResult( + status=TxStatus.SUCCESS, + error='', + message='success', + data={'gas': estimated_gas} + ) -def estimate_gas(web3, method, opts): +def estimate_gas(web3: Web3, method: ContractFunction, opts: TxParams) -> int: try: block_gas_limit = get_block_gas_limit(web3) except AttributeError: @@ -90,25 +113,25 @@ def estimate_gas(web3, method, opts): return normalized_estimated_gas -def build_tx_dict(method, *args, **kwargs): +def build_tx_dict(method: ContractFunction, *args: Any, **kwargs: Any) -> TxParams: base_fields = compose_base_fields(*args, **kwargs) return method.build_transaction(base_fields) def compose_base_fields( - nonce: int, + nonce: Nonce, gas_limit: int, - gas_price: Optional[int] = None, - max_fee_per_gas: Optional[int] = None, - max_priority_fee_per_gas: Optional[int] = None, - value: Optional[int] = 0, -) -> Dict: - fee_fields = { + gas_price: Wei | None = None, + max_fee_per_gas: Wei | None = None, + max_priority_fee_per_gas: Wei | None = None, + value: Wei = Wei(0), +) -> TxParams: + fee_fields = TxParams({ 'gas': gas_limit, 'nonce': nonce, 'value': value - } - if max_priority_fee_per_gas is not None: + }) + if max_priority_fee_per_gas is not None and max_fee_per_gas is not None: fee_fields.update({ 'maxPriorityFeePerGas': max_priority_fee_per_gas, 'maxFeePerGas': max_fee_per_gas @@ -121,12 +144,12 @@ def compose_base_fields( def transaction_from_method( - method, + method: ContractFunction, *, multiplier: Optional[float] = None, priority: Optional[int] = None, - **kwargs -) -> str: + **kwargs: Any +) -> TxParams: tx = build_tx_dict(method, **kwargs) logger.info( f'Tx: {method.fn_name}, ' @@ -137,11 +160,11 @@ def transaction_from_method( def compose_eth_transfer_tx( web3: Web3, - from_address: str, - to_address: str, - value: int, - **kwargs -) -> Dict: + from_address: ChecksumAddress, + to_address: ChecksumAddress, + value: Wei, + **kwargs: Any +) -> TxParams: nonce = get_eth_nonce(web3, from_address) base_fields = compose_base_fields( nonce=nonce, @@ -149,20 +172,25 @@ def compose_eth_transfer_tx( value=value, **kwargs ) - tx = { + tx = TxParams({ 'from': from_address, 'to': to_address, **base_fields - } + }) return tx -def retry_tx(tx=None, *, max_retries=3, timeout=-1): +def retry_tx( + tx: Callable[..., TxRes] | None = None, + *, + max_retries: int = 3, + timeout: int = -1 +) -> Callable[..., TxRes | None] | partial[Any]: if tx is None: return partial(retry_tx, max_retries=3, timeout=timeout) @wraps(tx) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> TxRes | None: return run_tx_with_retry( tx, *args, max_retries=max_retries, @@ -171,10 +199,14 @@ def wrapper(*args, **kwargs): return wrapper -def run_tx_with_retry(transaction, *args, max_retries=3, - retry_timeout=-1, - raise_for_status=True, - **kwargs) -> TxRes: +def run_tx_with_retry( + transaction: Callable[..., TxRes], + *args: Any, + max_retries: int = 3, + retry_timeout: int = -1, + raise_for_status: bool = True, + **kwargs: Any +) -> TxRes | None: attempt = 0 tx_res = None exp_timeout = 1 diff --git a/skale/types/allocation.py b/skale/types/allocation.py new file mode 100644 index 00000000..b61765c9 --- /dev/null +++ b/skale/types/allocation.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# +# This file is part of SKALE.py +# +# Copyright (C) 2024-Present SKALE Labs +# +# SKALE.py is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# SKALE.py is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with SKALE.py. If not, see . + +from enum import IntEnum +from typing import NewType, TypedDict + +from web3.types import Wei + + +class TimeUnit(IntEnum): + DAY = 0 + MONTH = 1 + YEAR = 2 + + +class BeneficiaryStatus(IntEnum): + UNKNOWN = 0 + CONFIRMED = 1 + ACTIVE = 2 + TERMINATED = 3 + + +PlanId = NewType('PlanId', int) + + +class Plan(TypedDict): + totalVestingDuration: int + vestingCliff: int + vestingIntervalTimeUnit: TimeUnit + vestingInterval: int + isDelegationAllowed: bool + isTerminatable: bool + + +class PlanWithId(Plan): + planId: PlanId + + +class BeneficiaryPlan(TypedDict): + status: BeneficiaryStatus + statusName: str + planId: PlanId + startMonth: int + fullAmount: Wei + amountAfterLockup: Wei diff --git a/skale/types/delegation.py b/skale/types/delegation.py new file mode 100644 index 00000000..896cbf78 --- /dev/null +++ b/skale/types/delegation.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# +# This file is part of SKALE.py +# +# Copyright (C) 2024-Present SKALE Labs +# +# SKALE.py is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# SKALE.py is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with SKALE.py. If not, see . + +from enum import Enum +from typing import NewType, TypedDict + +from eth_typing import ChecksumAddress +from web3.types import Wei + +from skale.types.validator import ValidatorId + + +DelegationId = NewType('DelegationId', int) + + +class DelegationStatus(Enum): + PROPOSED = 0 + ACCEPTED = 1 + CANCELED = 2 + REJECTED = 3 + DELEGATED = 4 + UNDELEGATION_REQUESTED = 5 + COMPLETED = 6 + + +class Delegation(TypedDict): + address: ChecksumAddress + validator_id: ValidatorId + amount: Wei + delegation_period: int + created: int + started: int + finished: int + info: str + + +class FullDelegation(Delegation): + id: DelegationId + status: DelegationStatus diff --git a/skale/types/dkg.py b/skale/types/dkg.py new file mode 100644 index 00000000..27d13b45 --- /dev/null +++ b/skale/types/dkg.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# +# This file is part of SKALE.py +# +# Copyright (C) 2024-Present SKALE Labs +# +# SKALE.py is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# SKALE.py is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with SKALE.py. If not, see . + +from collections import namedtuple +from typing import List, NamedTuple, NewType, Tuple + +from eth_typing import HexStr + + +Fp2Point = namedtuple('Fp2Point', ['a', 'b']) + + +class G2Point(NamedTuple): + x: Fp2Point + y: Fp2Point + + +VerificationVector = NewType('VerificationVector', List[G2Point]) + + +class KeyShare(NamedTuple): + publicKey: Tuple[bytes | HexStr, bytes | HexStr] + share: bytes | HexStr diff --git a/skale/types/node.py b/skale/types/node.py new file mode 100644 index 00000000..8f91316a --- /dev/null +++ b/skale/types/node.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +# +# This file is part of SKALE.py +# +# Copyright (C) 2024-Present SKALE Labs +# +# SKALE.py is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# SKALE.py is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with SKALE.py. If not, see . + +from enum import IntEnum +from typing import List, NewType, TypedDict + +from eth_typing import BlockNumber + +from skale.types.schain import SchainStructureWithStatus +from skale.types.validator import ValidatorId + + +NodeId = NewType("NodeId", int) +Port = NewType("Port", int) + + +class NodeStatus(IntEnum): + ACTIVE = 0 + LEAVING = 1 + LEFT = 2 + IN_MAINTENANCE = 3 + + +class Node(TypedDict): + name: str + ip: bytes + publicIP: bytes + port: Port + start_block: BlockNumber + last_reward_date: int + finish_time: int + status: NodeStatus + validator_id: ValidatorId + publicKey: str + domain_name: str + + +class NodeWithId(Node): + id: NodeId + + +class NodeWithSchains(NodeWithId): + schains: List[SchainStructureWithStatus] diff --git a/skale/types/rotation.py b/skale/types/rotation.py new file mode 100644 index 00000000..78e22ad7 --- /dev/null +++ b/skale/types/rotation.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +# +# This file is part of SKALE.py +# +# Copyright (C) 2024-Present SKALE Labs +# +# SKALE.py is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# SKALE.py is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with SKALE.py. If not, see . + +from collections import namedtuple +from dataclasses import dataclass +from typing import TypedDict + +from skale.types.node import NodeId +from skale.types.schain import SchainHash + + +RotationNodeData = namedtuple('RotationNodeData', ['index', 'node_id', 'public_key']) + + +class NodesSwap(TypedDict): + leaving_node_id: NodeId + new_node_id: NodeId + + +class BlsPublicKey(TypedDict): + blsPublicKey0: str + blsPublicKey1: str + blsPublicKey2: str + blsPublicKey3: str + + +class NodesGroup(TypedDict): + rotation: NodesSwap | None + nodes: dict[NodeId, RotationNodeData] + finish_ts: int | None + bls_public_key: BlsPublicKey | None + + +@dataclass +class Rotation: + leaving_node_id: NodeId + new_node_id: NodeId + freeze_until: int + rotation_counter: int + + +class RotationSwap(TypedDict): + schain_id: SchainHash + finished_rotation: int diff --git a/skale/types/schain.py b/skale/types/schain.py new file mode 100644 index 00000000..c7f07da1 --- /dev/null +++ b/skale/types/schain.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# +# This file is part of SKALE.py +# +# Copyright (C) 2024-Present SKALE Labs +# +# SKALE.py is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# SKALE.py is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with SKALE.py. If not, see . + +from dataclasses import dataclass +from typing import NewType + +from eth_typing import ChecksumAddress +from web3.types import Wei + +from skale.dataclasses.schain_options import SchainOptions + + +SchainName = NewType('SchainName', str) +SchainHash = NewType('SchainHash', bytes) +SchainOption = tuple[str, bytes] + + +@dataclass +class Schain: + name: SchainName + mainnetOwner: ChecksumAddress + indexInOwnerList: int + partOfNode: int + lifetime: int + startDate: int + startBlock: int + deposit: Wei + index: int + generation: int + originator: ChecksumAddress + + +@dataclass +class SchainStructure(Schain): + chainId: SchainHash + options: SchainOptions + + +@dataclass +class SchainStructureWithStatus(SchainStructure): + active: bool diff --git a/skale/types/validator.py b/skale/types/validator.py new file mode 100644 index 00000000..9eeaebc1 --- /dev/null +++ b/skale/types/validator.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# +# This file is part of SKALE.py +# +# Copyright (C) 2024-Present SKALE Labs +# +# SKALE.py is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# SKALE.py is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with SKALE.py. If not, see . + +from typing import NewType, TypedDict + +from eth_typing import ChecksumAddress +from web3.types import Wei + + +ValidatorId = NewType('ValidatorId', int) + + +class Validator(TypedDict): + name: str + validator_address: ChecksumAddress + requested_address: ChecksumAddress + description: str + fee_rate: int + registration_time: int + minimum_delegation_amount: Wei + accept_new_requests: bool + trusted: bool + + +class ValidatorWithId(Validator): + id: ValidatorId diff --git a/skale/utils/account_tools.py b/skale/utils/account_tools.py index cb438fe1..6b057aeb 100644 --- a/skale/utils/account_tools.py +++ b/skale/utils/account_tools.py @@ -18,11 +18,16 @@ # along with SKALE.py. If not, see . """ Account utilities """ +from __future__ import annotations +from decimal import Decimal import logging -from typing import Optional +from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Type, TypedDict +from eth_typing import ChecksumAddress from web3 import Web3 +from web3.types import TxReceipt, Wei +from skale.transactions.result import TxRes from skale.transactions.tools import compose_eth_transfer_tx from skale.utils.constants import LONG_LINE from skale.wallets import LedgerWallet, Web3Wallet @@ -32,26 +37,40 @@ wait_for_confirmation_blocks ) +if TYPE_CHECKING: + from skale.skale_manager import SkaleManager + from skale.wallets.common import BaseWallet + + logger = logging.getLogger(__name__) -WALLET_TYPE_TO_CLASS = { +class AccountData(TypedDict): + address: ChecksumAddress + private_key: str + + +WALLET_TYPE_TO_CLASS: Dict[str, Type[LedgerWallet] | Type[Web3Wallet]] = { 'ledger': LedgerWallet, 'web3': Web3Wallet } -def create_wallet(wallet_type='web3', *args, **kwargs): +def create_wallet( + wallet_type: Literal['web3'] | Literal['ledger'] = 'web3', + *args: Any, + **kwargs: Any +) -> LedgerWallet | Web3Wallet: return WALLET_TYPE_TO_CLASS[wallet_type](*args, **kwargs) def send_tokens( - skale, - receiver_address, - amount, - *args, - **kwargs -): + skale: SkaleManager, + receiver_address: ChecksumAddress, + amount: Wei, + *args: Any, + **kwargs: Any +) -> TxRes: logger.info( f'Sending {amount} SKALE tokens from {skale.wallet.address} => ' f'{receiver_address}' @@ -68,17 +87,17 @@ def send_tokens( def send_eth( web3: Web3, - wallet, - receiver_address, - amount, - *args, + wallet: BaseWallet, + receiver_address: ChecksumAddress, + amount: Wei, + *args: Any, gas_price: Optional[int] = None, wait_for: bool = True, confirmation_blocks: int = 0, multiplier: Optional[int] = None, priority: Optional[int] = None, - **kwargs -): + **kwargs: Any +) -> TxReceipt: logger.info( f'Sending {amount} ETH from {wallet.address} => ' f'{receiver_address}' @@ -86,12 +105,12 @@ def send_eth( wei_amount = web3.to_wei(amount, 'ether') gas_price = gas_price or default_gas_price(web3) tx = compose_eth_transfer_tx( - web3=web3, - *args, + web3, + wallet.address, + receiver_address, + wei_amount, gas_price=gas_price, - from_address=wallet.address, - to_address=receiver_address, - value=wei_amount, + *args, **kwargs ) tx_hash = wallet.sign_and_send( @@ -110,11 +129,11 @@ def send_eth( return receipt -def account_eth_balance_wei(web3, address): +def account_eth_balance_wei(web3: Web3, address: ChecksumAddress) -> Wei: return web3.eth.get_balance(address) -def check_ether_balance(web3, address): +def check_ether_balance(web3: Web3, address: ChecksumAddress) -> int | Decimal: balance_wei = account_eth_balance_wei(web3, address) balance = web3.from_wei(balance_wei, 'ether') @@ -122,26 +141,28 @@ def check_ether_balance(web3, address): return balance -def check_skale_balance(skale, address): +def check_skale_balance(skale: SkaleManager, address: ChecksumAddress) -> int | Decimal: balance_wei = skale.token.get_balance(address) balance = skale.web3.from_wei(balance_wei, 'ether') logger.info(f'{address} balance: {balance} SKALE') return balance -def generate_account(web3): +def generate_account(web3: Web3) -> AccountData: account = web3.eth.account.create() private_key = account.key.hex() logger.info(f'Generated account: {account.address}') - return {'address': account.address, 'private_key': private_key} + return AccountData({'address': account.address, 'private_key': private_key}) -def generate_accounts(skale, - base_wallet, - n_wallets, - skale_amount, - eth_amount, - debug=False): +def generate_accounts( + skale: SkaleManager, + base_wallet: BaseWallet, + n_wallets: int, + skale_amount: Wei, + eth_amount: Wei, + debug: bool = False +) -> List[AccountData]: n_wallets = int(n_wallets) results = [] diff --git a/skale/utils/contract_info.py b/skale/utils/contract_info.py index 33cfe1d6..aaf1c77f 100644 --- a/skale/utils/contract_info.py +++ b/skale/utils/contract_info.py @@ -18,15 +18,19 @@ # along with SKALE.py. If not, see . """ Contract info utilities """ -from typing import NamedTuple +from __future__ import annotations +from typing import Generic, NamedTuple, Type, TYPE_CHECKING -from skale.contracts.base_contract import BaseContract -from skale.utils.contract_types import ContractTypes +from skale.contracts.base_contract import SkaleType +if TYPE_CHECKING: + from skale.contracts.base_contract import BaseContract + from skale.utils.contract_types import ContractTypes -class ContractInfo(NamedTuple): + +class ContractInfo(NamedTuple, Generic[SkaleType]): name: str contract_name: str - contract_class: BaseContract + contract_class: Type[BaseContract[SkaleType]] type: ContractTypes upgradeable: bool diff --git a/skale/utils/contracts_provision/__init__.py b/skale/utils/contracts_provision/__init__.py index bf403f69..645f9aa5 100644 --- a/skale/utils/contracts_provision/__init__.py +++ b/skale/utils/contracts_provision/__init__.py @@ -17,7 +17,10 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . -from skale.contracts.allocator.allocator import TimeUnit +from web3.types import Wei + +from skale.types.allocation import TimeUnit +from skale.types.validator import ValidatorId # manager test constants @@ -29,7 +32,7 @@ DEFAULT_DOMAIN_NAME = 'skale.test' INITIAL_DELEGATION_PERIOD = 2 -D_VALIDATOR_ID = 1 +D_VALIDATOR_ID = ValidatorId(1) D_VALIDATOR_NAME = 'test' D_VALIDATOR_DESC = 'test' D_VALIDATOR_FEE = 10 @@ -60,6 +63,6 @@ POLL_INTERVAL = 2 -TEST_SKALE_AMOUNT = 100000 +TEST_SKALE_AMOUNT = Wei(100000) D_PLAN_ID = 1 diff --git a/skale/utils/contracts_provision/allocator.py b/skale/utils/contracts_provision/allocator.py index b5bbd017..b7d4dc15 100644 --- a/skale/utils/contracts_provision/allocator.py +++ b/skale/utils/contracts_provision/allocator.py @@ -19,16 +19,23 @@ from time import sleep +from web3.contract.contract import ContractEvent +from web3.types import LogReceipt, Wei +from web3._utils.filters import LogFilter + +from skale.skale_allocator import SkaleAllocator +from skale.skale_manager import SkaleManager from skale.utils.account_tools import send_tokens from skale.utils.contracts_provision import ( TEST_SKALE_AMOUNT, TEST_VESTING_CLIFF, TEST_TOTAL_VESTING_DURATION, TEST_VESTING_INTERVAL_TIME_UNIT, TEST_VESTING_INTERVAL, TEST_CAN_DELEGATE, TEST_IS_TERMINATABLE, POLL_INTERVAL, TEST_START_MONTH, TEST_FULL_AMOUNT, TEST_LOCKUP_AMOUNT ) +from skale.wallets.common import BaseWallet -def _catch_event(event_obj): - event_filter = event_obj.createFilter( +def _catch_event(event_obj: ContractEvent) -> LogReceipt: + event_filter: LogFilter = event_obj.create_filter( fromBlock=0, toBlock='latest' ) @@ -38,7 +45,11 @@ def _catch_event(event_obj): sleep(POLL_INTERVAL) -def transfer_tokens_to_allocator(skale_manager, skale_allocator, amount=TEST_SKALE_AMOUNT): +def transfer_tokens_to_allocator( + skale_manager: SkaleManager, + skale_allocator: SkaleAllocator, + amount: Wei = TEST_SKALE_AMOUNT +) -> None: send_tokens(skale_manager, skale_allocator.allocator.address, amount) @@ -46,7 +57,7 @@ def transfer_tokens_to_allocator(skale_manager, skale_allocator, amount=TEST_SKA # send_tokens(skale, skale.wallet, skale.token_launch_manager.address, amount) -def add_test_plan(skale_allocator): +def add_test_plan(skale_allocator: SkaleAllocator) -> int: skale_allocator.allocator.add_plan( vesting_cliff=TEST_VESTING_CLIFF, total_vesting_duration=TEST_TOTAL_VESTING_DURATION, @@ -58,7 +69,11 @@ def add_test_plan(skale_allocator): return len(skale_allocator.allocator.get_all_plans()) -def connect_test_beneficiary(skale_allocator, plan_id, wallet): +def connect_test_beneficiary( + skale_allocator: SkaleAllocator, + plan_id: int, + wallet: BaseWallet +) -> None: skale_allocator.allocator.connect_beneficiary_to_plan( beneficiary_address=wallet.address, plan_id=plan_id, diff --git a/skale/utils/contracts_provision/fake_multisig_contract.py b/skale/utils/contracts_provision/fake_multisig_contract.py index 915e0f99..19ac2c25 100644 --- a/skale/utils/contracts_provision/fake_multisig_contract.py +++ b/skale/utils/contracts_provision/fake_multisig_contract.py @@ -20,11 +20,13 @@ import os import json -from skale.transactions.tools import transaction_from_method +from web3 import Web3 + from skale.utils.web3_utils import ( get_eth_nonce, wait_for_receipt_by_blocks ) +from skale.wallets.common import BaseWallet # Usage note: to change this contract update the code, compile it and put the new bytecode and # new ABI below @@ -76,22 +78,20 @@ FAKE_MULTISIG_CONSTRUCTOR_GAS = 1000000 -def deploy_fake_multisig_contract(web3, wallet): +def deploy_fake_multisig_contract(web3: Web3, wallet: BaseWallet) -> None: print('Going to deploy simple payable contract') FakeMultisigContract = web3.eth.contract(abi=FAKE_MULTISIG_ABI, bytecode=FAKE_MULTISIG_BYTECODE) constructor = FakeMultisigContract.constructor() - constructor.fn_name = 'fake_multisig_constructor' - tx = transaction_from_method( - constructor, - nonce=get_eth_nonce(web3, wallet.address), - gas_price=3 * 10 ** 9, - gas_limit=FAKE_MULTISIG_CONSTRUCTOR_GAS - ) + tx = constructor.build_transaction({ + 'nonce': get_eth_nonce(web3, wallet.address), + 'gasPrice': 3 * 10 ** 9, + 'gas': FAKE_MULTISIG_CONSTRUCTOR_GAS + }) tx_hash = wallet.sign_and_send(tx) receipt = wait_for_receipt_by_blocks(web3, tx_hash) - print(f'Sample contract successfully deployed: {receipt.contractAddress}') + print(f"Sample contract successfully deployed: {receipt['contractAddress']}") content = { - 'address': receipt.contractAddress, + 'address': receipt['contractAddress'], 'abi': FAKE_MULTISIG_ABI } with open(FAKE_MULTISIG_DATA_PATH, 'w') as outfile: diff --git a/skale/utils/contracts_provision/main.py b/skale/utils/contracts_provision/main.py index 38542390..b554903e 100644 --- a/skale/utils/contracts_provision/main.py +++ b/skale/utils/contracts_provision/main.py @@ -17,10 +17,16 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . +from typing import List, Tuple +from eth_typing import ChecksumAddress from web3 import Web3 +from web3.types import RPCEndpoint -from skale.contracts.manager.nodes import NodeStatus +from skale.dataclasses.schain_options import SchainOptions +from skale.skale_manager import SkaleManager from skale.transactions.result import TxRes +from skale.types.node import NodeId, NodeStatus +from skale.types.validator import ValidatorId from skale.utils.contracts_provision import ( D_VALIDATOR_ID, D_VALIDATOR_MIN_DEL, @@ -48,30 +54,31 @@ def _skip_evm_time(web3: Web3, seconds: int, mine: bool = True) -> int: """For test purposes only, works only with hardhat node""" - res = web3.provider.make_request('evm_increaseTime', [seconds]) - web3.provider.make_request("evm_mine", []) + res = web3.provider.make_request(RPCEndpoint('evm_increaseTime'), [seconds]) + if mine: + web3.provider.make_request(RPCEndpoint('evm_mine'), []) return int(res['result']) def set_automining(web3: Web3, value: bool) -> int: - res = web3.provider.make_request('evm_setAutomine', [value]) + res = web3.provider.make_request(RPCEndpoint('evm_setAutomine'), [value]) return int(res['result']) def set_mining_interval(web3: Web3, ms: int) -> int: - res = web3.provider.make_request('evm_setIntervalMining', [ms]) + res = web3.provider.make_request(RPCEndpoint('evm_setIntervalMining'), [ms]) return int(res['result']) -def set_default_mining_interval(web3) -> int: +def set_default_mining_interval(web3: Web3) -> int: return set_mining_interval(web3, DEFAULT_MINING_INTERVAL) -def add_test_permissions(skale): +def add_test_permissions(skale: SkaleManager) -> None: add_all_permissions(skale, skale.wallet.address) -def add_all_permissions(skale, address): +def add_all_permissions(skale: SkaleManager, address: ChecksumAddress) -> None: default_admin_role = skale.manager.default_admin_role() if not skale.manager.has_role(default_admin_role, address): skale.manager.grant_role(default_admin_role, address) @@ -137,7 +144,7 @@ def add_all_permissions(skale, address): skale.slashing_table.grant_role(penalty_setter_role, address) -def add_test2_schain_type(skale) -> TxRes: +def add_test2_schain_type(skale: SkaleManager) -> TxRes: part_of_node = 1 number_of_nodes = 2 return skale.schains_internal.add_schain_type( @@ -145,7 +152,7 @@ def add_test2_schain_type(skale) -> TxRes: ) -def add_test4_schain_type(skale) -> TxRes: +def add_test4_schain_type(skale: SkaleManager) -> TxRes: part_of_node = 1 number_of_nodes = 4 return skale.schains_internal.add_schain_type( @@ -153,7 +160,7 @@ def add_test4_schain_type(skale) -> TxRes: ) -def cleanup_nodes(skale, ids=()): +def cleanup_nodes(skale: SkaleManager, ids: list[NodeId] | None = None) -> None: active_ids = filter( lambda i: skale.nodes.get_node_status(i) == NodeStatus.ACTIVE, ids or skale.nodes.get_active_node_ids() @@ -164,27 +171,27 @@ def cleanup_nodes(skale, ids=()): skale.manager.node_exit(node_id) -def cleanup_schains(skale): +def cleanup_schains(skale: SkaleManager) -> None: for schain_id in skale.schains_internal.get_all_schains_ids(): schain_data = skale.schains.get(schain_id) - schain_name = schain_data.get('name', None) + schain_name = schain_data.name if schain_name is not None: skale.manager.delete_schain_by_root(schain_name, wait_for=True) -def cleanup_nodes_schains(skale): +def cleanup_nodes_schains(skale: SkaleManager) -> None: print('Cleanup nodes and schains') cleanup_schains(skale) cleanup_nodes(skale) -def create_clean_schain(skale): +def create_clean_schain(skale: SkaleManager) -> str: cleanup_nodes_schains(skale) - create_nodes(skale) + create_nodes([skale]) return create_schain(skale, random_name=True) -def create_node(skale) -> str: +def create_node(skale: SkaleManager) -> str: cleanup_nodes_schains(skale) ip, public_ip, port, name = generate_random_node_data() skale.manager.create_node( @@ -198,13 +205,13 @@ def create_node(skale) -> str: return name -def validator_exist(skale): +def validator_exist(skale: SkaleManager) -> bool: return skale.validator_service.validator_address_exists( skale.wallet.address ) -def add_delegation_period(skale): +def add_delegation_period(skale: SkaleManager) -> None: is_added = skale.delegation_period_manager.is_delegation_period_allowed(D_DELEGATION_PERIOD) if not is_added: skale.delegation_period_manager.set_delegation_period( @@ -214,7 +221,7 @@ def add_delegation_period(skale): ) -def setup_validator(skale): +def setup_validator(skale: SkaleManager) -> int: """Create and activate a validator""" set_test_msr(skale) print('Address', skale.wallet.address) @@ -234,7 +241,7 @@ def setup_validator(skale): return validator_id -def link_address_to_validator(skale): +def link_address_to_validator(skale: SkaleManager) -> None: print('Linking address to validator') signature = skale.validator_service.get_link_node_signature(D_VALIDATOR_ID) tx_res = skale.validator_service.link_node_address( @@ -245,7 +252,11 @@ def link_address_to_validator(skale): tx_res.raise_for_status() -def link_nodes_to_validator(skale, validator_id, node_skale_objs=()): +def link_nodes_to_validator( + skale: SkaleManager, + validator_id: ValidatorId, + node_skale_objs: Tuple[SkaleManager] | None = None +) -> None: print('Linking address to validator') node_skale_objs = node_skale_objs or (skale,) validator_id = validator_id or D_VALIDATOR_ID @@ -259,7 +270,7 @@ def link_nodes_to_validator(skale, validator_id, node_skale_objs=()): ) -def skip_delegation_delay(skale, delegation_id): +def skip_delegation_delay(skale: SkaleManager, delegation_id: int) -> None: print(f'Activating delegation with ID {delegation_id}') skale.token_state._skip_transition_delay( delegation_id, @@ -267,7 +278,7 @@ def skip_delegation_delay(skale, delegation_id): ) -def accept_pending_delegation(skale, delegation_id): +def accept_pending_delegation(skale: SkaleManager, delegation_id: int) -> None: print(f'Accepting delegation with ID: {delegation_id}') skale.delegation_controller.accept_pending_delegation( delegation_id=delegation_id, @@ -275,23 +286,23 @@ def accept_pending_delegation(skale, delegation_id): ) -def get_test_delegation_amount(skale): +def get_test_delegation_amount(skale: SkaleManager) -> int: msr = skale.constants_holder.msr() return msr * 30 -def set_test_msr(skale, msr=D_VALIDATOR_MIN_DEL): +def set_test_msr(skale: SkaleManager, msr: int = D_VALIDATOR_MIN_DEL) -> None: skale.constants_holder._set_msr( new_msr=msr, wait_for=True ) -def set_test_mda(skale): +def set_test_mda(skale: SkaleManager) -> None: skale.validator_service.set_validator_mda(0, wait_for=True) -def delegate_to_validator(skale, validator_id=D_VALIDATOR_ID): +def delegate_to_validator(skale: SkaleManager, validator_id: int = D_VALIDATOR_ID) -> None: print(f'Delegating tokens to validator ID: {validator_id}') skale.delegation_controller.delegate( validator_id=validator_id, @@ -302,12 +313,12 @@ def delegate_to_validator(skale, validator_id=D_VALIDATOR_ID): ) -def enable_validator(skale, validator_id=D_VALIDATOR_ID): +def enable_validator(skale: SkaleManager, validator_id: int = D_VALIDATOR_ID) -> None: print(f'Enabling validator ID: {D_VALIDATOR_ID}') skale.validator_service._enable_validator(validator_id, wait_for=True) -def create_validator(skale): +def create_validator(skale: SkaleManager) -> None: print('Creating default validator') skale.validator_service.register_validator( name=D_VALIDATOR_NAME, @@ -318,7 +329,7 @@ def create_validator(skale): ) -def create_nodes(skales, names=()): +def create_nodes(skales: List[SkaleManager], names: List[str] | None = None) -> list[int]: # create couple of nodes print('Creating two nodes') node_names = names or (DEFAULT_NODE_NAME, SECOND_NODE_NAME) @@ -333,19 +344,19 @@ def create_nodes(skales, names=()): wait_for=True ) ids = [ - skale.nodes.node_name_to_index(name) + skales[0].nodes.node_name_to_index(name) for name in node_names ] return ids def create_schain( - skale, - schain_name=DEFAULT_SCHAIN_NAME, - schain_type=1, - random_name=False, - schain_options=None -): + skale: SkaleManager, + schain_name: str = DEFAULT_SCHAIN_NAME, + schain_type: int = 1, + random_name: bool = False, + schain_options: SchainOptions | None = None +) -> str: print('Creating schain') # create 1 s-chain type_of_nodes, lifetime_seconds, name = generate_random_schain_data(skale) diff --git a/skale/utils/contracts_provision/utils.py b/skale/utils/contracts_provision/utils.py index 17e87648..08ef8b71 100644 --- a/skale/utils/contracts_provision/utils.py +++ b/skale/utils/contracts_provision/utils.py @@ -20,27 +20,30 @@ import random import string +from skale.skale_manager import SkaleManager +from skale.types.node import Port -def generate_random_ip(): + +def generate_random_ip() -> str: return '.'.join('%s' % random.randint(0, 255) for i in range(4)) -def generate_random_name(len=8): +def generate_random_name(length: int = 8) -> str: return ''.join( - random.choices(string.ascii_uppercase + string.digits, k=len) + random.choices(string.ascii_uppercase + string.digits, k=length) ) -def generate_random_port(): - return random.randint(0, 60000) +def generate_random_port() -> Port: + return Port(random.randint(0, 60000)) -def generate_random_node_data(): +def generate_random_node_data() -> tuple[str, str, int, str]: return generate_random_ip(), generate_random_ip(), \ generate_random_port(), generate_random_name() -def generate_random_schain_data(skale): +def generate_random_schain_data(skale: SkaleManager) -> tuple[int, int, str]: schain_type = skale.schains_internal.number_of_schain_types() lifetime_seconds = 3600 # 1 hour return schain_type, lifetime_seconds, generate_random_name() diff --git a/skale/utils/exceptions.py b/skale/utils/exceptions.py index df417dd4..b270dbc6 100644 --- a/skale/utils/exceptions.py +++ b/skale/utils/exceptions.py @@ -36,6 +36,6 @@ class SChainNotFoundException(Exception): class InvalidNodeIdError(Exception): """Raised when wrong node id passed""" - def __init__(self, node_id): + def __init__(self, node_id: int): message = f'Node with ID = {node_id} doesn\'t exist!' super().__init__(message) diff --git a/skale/utils/helper.py b/skale/utils/helper.py index e3a69e46..2d52c40c 100644 --- a/skale/utils/helper.py +++ b/skale/utils/helper.py @@ -18,6 +18,8 @@ # along with SKALE.py. If not, see . """ SKALE helper utilities """ +from __future__ import annotations + import ipaddress import json import logging @@ -27,18 +29,38 @@ import sys from logging import Formatter, StreamHandler from random import randint +from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, cast from skale.config import ENV +from skale.types.node import Port + +if TYPE_CHECKING: + from skale.contracts.base_contract import SkaleType + from skale.utils.contract_info import ContractInfo logger = logging.getLogger(__name__) -def decapitalize(s): +def decapitalize(s: str) -> str: return s[:1].lower() + s[1:] if s else '' -def format_fields(fields, flist=False): +WrapperReturnType = Dict[str, Any] | List[Dict[str, Any]] | None + + +def format_fields( + fields: list[str], + flist: bool = False +) -> Callable[ + [ + Callable[ + ..., + List[Any] + ] + ], + Callable[..., WrapperReturnType] +]: """ Transform array to object with passed fields Usage: @@ -49,8 +71,16 @@ def my_method() => {'field_name1': 0, 'field_name2': 'Test'} """ - def real_decorator(function): - def wrapper(*args, **kwargs): + def real_decorator( + function: Callable[ + ..., + List[Any] + ] + ) -> Callable[..., WrapperReturnType]: + def wrapper( + *args: Any, + **kwargs: Any + ) -> WrapperReturnType: result = function(*args, **kwargs) if result is None: @@ -86,7 +116,7 @@ def ip_to_bytes(ip: str) -> bytes: # pragma: no cover return socket.inet_aton(ip) -def is_valid_ipv4_address(address): +def is_valid_ipv4_address(address: str) -> bool: try: ipaddress.IPv4Address(address) except ValueError: @@ -94,43 +124,43 @@ def is_valid_ipv4_address(address): return True -def get_abi(abi_filepath: string = None): +def get_abi(abi_filepath: str | None = None) -> dict[str, Any]: if abi_filepath: with open(abi_filepath, encoding='utf-8') as data_file: - return json.load(data_file) + return cast(dict[str, Any], json.load(data_file)) return {} -def get_skale_manager_address(abi_filepath: string = None) -> str: - return get_abi(abi_filepath)['skale_manager_address'] +def get_skale_manager_address(abi_filepath: str | None = None) -> str: + return cast(str, get_abi(abi_filepath)['skale_manager_address']) -def get_allocator_address(abi_filepath: string = None) -> str: - return get_abi(abi_filepath)['allocator_address'] +def get_allocator_address(abi_filepath: str | None = None) -> str: + return cast(str, get_abi(abi_filepath)['allocator_address']) -def generate_nonce(): # pragma: no cover +def generate_nonce() -> int: # pragma: no cover return randint(0, 65534) -def random_string(size=6, chars=string.ascii_lowercase): # pragma: no cover +def random_string(size: int = 6, chars: str = string.ascii_lowercase) -> str: # pragma: no cover return ''.join(random.choice(chars) for x in range(size)) -def generate_random_ip(): # pragma: no cover +def generate_random_ip() -> str: # pragma: no cover return '.'.join('%s' % random.randint(0, 255) for i in range(4)) -def generate_random_name(len=8): # pragma: no cover +def generate_random_name(length: int = 8) -> str: # pragma: no cover return ''.join( - random.choices(string.ascii_uppercase + string.digits, k=len)) + random.choices(string.ascii_uppercase + string.digits, k=length)) -def generate_random_port(): # pragma: no cover - return random.randint(0, 60000) +def generate_random_port() -> Port: # pragma: no cover + return Port(random.randint(0, 60000)) -def generate_custom_config(ip, ws_port): +def generate_custom_config(ip: str, ws_port: Port) -> dict[str, str | Port]: if not ip or not ws_port: raise ValueError( f'For custom init you should provide ip and ws_port: {ip}, {ws_port}' @@ -141,17 +171,17 @@ def generate_custom_config(ip, ws_port): } -def add_0x_prefix(bytes_string): # pragma: no cover +def add_0x_prefix(bytes_string: str) -> str: # pragma: no cover return '0x' + bytes_string -def rm_0x_prefix(bytes_string): +def rm_0x_prefix(bytes_string: str) -> str: if bytes_string.startswith('0x'): return bytes_string[2:] return bytes_string -def init_default_logger(): # pragma: no cover +def init_default_logger() -> None: # pragma: no cover handlers = [] formatter = Formatter( '%(asctime)s - %(name)s - %(levelname)s - %(message)s') @@ -164,7 +194,7 @@ def init_default_logger(): # pragma: no cover logging.basicConfig(level=logging.DEBUG, handlers=handlers) -def chunk(in_string, num_chunks): # pragma: no cover +def chunk(in_string: str, num_chunks: int) -> Generator[str, None, None]: # pragma: no cover chunk_size = len(in_string) // num_chunks if len(in_string) % num_chunks: chunk_size += 1 @@ -179,23 +209,25 @@ def chunk(in_string, num_chunks): # pragma: no cover yield ''.join(accumulator) -def split_public_key(public_key: str) -> list: +def split_public_key(public_key: str) -> list[bytes]: public_key = rm_0x_prefix(public_key) pk_parts = list(chunk(public_key, 2)) return list(map(bytes.fromhex, pk_parts)) -def get_contracts_info(contracts_data): +def get_contracts_info( + contracts_data: list[ContractInfo[SkaleType]] +) -> dict[str, ContractInfo[SkaleType]]: contracts_info = {} for contract_info in contracts_data: contracts_info[contract_info.name] = contract_info return contracts_info -def to_camel_case(snake_str): +def to_camel_case(snake_str: str) -> str: components = snake_str.split('_') return components[0] + ''.join(x.title() for x in components[1:]) -def is_test_env(): +def is_test_env() -> bool: return "pytest" in sys.modules or ENV == 'test' diff --git a/skale/utils/random_names/generator.py b/skale/utils/random_names/generator.py index 57263114..d90f6912 100644 --- a/skale/utils/random_names/generator.py +++ b/skale/utils/random_names/generator.py @@ -22,21 +22,21 @@ SOUND_AND_APPEARANCE_ADJECTIVES, POSITIVE_AND_TIME_ADJECTIVES -def generate_random_node_name(): +def generate_random_node_name() -> str: return generate_name(CONSTELLATIONS, POSITIVE_AND_TIME_ADJECTIVES) -def generate_random_schain_name(): +def generate_random_schain_name() -> str: return generate_name(STARS, SOUND_AND_APPEARANCE_ADJECTIVES) -def generate_name(noun_dict, adjective_dict): +def generate_name(noun_dict: list[str], adjective_dict: list[str]) -> str: noun = get_random_word_from_dict(noun_dict) adjective = get_random_word_from_dict(adjective_dict) return f'{adjective}-{noun}' -def get_random_word_from_dict(vocabulary_dict): +def get_random_word_from_dict(vocabulary_dict: list[str]) -> str: return random.choice(vocabulary_dict).lower().replace(" ", "-") diff --git a/skale/utils/web3_utils.py b/skale/utils/web3_utils.py index eae748ea..06fd6ce7 100644 --- a/skale/utils/web3_utils.py +++ b/skale/utils/web3_utils.py @@ -21,16 +21,26 @@ import logging import os import time -from typing import Iterable +from typing import Any, Callable, Dict, Iterable from urllib.parse import urlparse -from eth_keys import keys +from eth_keys.main import lazy_key_api as keys +from eth_typing import Address, AnyAddress, BlockNumber, ChecksumAddress, HexStr from web3 import Web3, WebsocketProvider, HTTPProvider from web3.exceptions import TransactionNotFound -from web3.middleware import ( - attrdict_middleware, - geth_poa_middleware, - http_retry_request_middleware +from web3.middleware.attrdict import attrdict_middleware +from web3.middleware.exception_retry_request import http_retry_request_middleware +from web3.middleware.geth_poa import geth_poa_middleware +from web3.providers.base import JSONBaseProvider +from web3.types import ( + _Hash32, + ENS, + Middleware, + Nonce, + RPCEndpoint, + RPCResponse, + Timestamp, + TxReceipt ) import skale.config as config @@ -50,7 +60,11 @@ DEFAULT_BLOCKS_TO_WAIT = 50 -def get_provider(endpoint, timeout=DEFAULT_HTTP_TIMEOUT, request_kwargs={}): +def get_provider( + endpoint: str, + timeout: int = DEFAULT_HTTP_TIMEOUT, + request_kwargs: Dict[str, Any] | None = None +) -> JSONBaseProvider: scheme = urlparse(endpoint).scheme if scheme == 'ws' or scheme == 'wss': kwargs = request_kwargs or {'max_size': WS_MAX_MESSAGE_DATA_BYTES} @@ -58,7 +72,7 @@ def get_provider(endpoint, timeout=DEFAULT_HTTP_TIMEOUT, request_kwargs={}): websocket_kwargs=kwargs) if scheme == 'http' or scheme == 'https': - kwargs = {'timeout': timeout, **request_kwargs} + kwargs = {'timeout': timeout, **(request_kwargs or {})} return HTTPProvider(endpoint, request_kwargs=kwargs) raise Exception( @@ -87,21 +101,39 @@ def save_last_known_block_number(state_path: str, block_number: int) -> None: last_block_file.write(str(block_number)) -def outdated_client_time_msg(method, current_time, latest_block_timestamp, allowed_ts_diff): +def outdated_client_time_msg( + method: RPCEndpoint, + current_time: float, + latest_block_timestamp: Timestamp, + allowed_ts_diff: int +) -> str: return f'{method} failed; \ current_time: {current_time}, latest_block_timestamp: {latest_block_timestamp}, \ allowed_ts_diff: {allowed_ts_diff}' -def outdated_client_file_msg(method, latest_block_number, saved_number, state_path): +def outdated_client_file_msg( + method: RPCEndpoint, + latest_block_number: BlockNumber, + saved_number: int, + state_path: str +) -> str: return f'{method} failed: latest_block_number: {latest_block_number}, \ saved_number: {saved_number}, state_path: {state_path}' -def make_client_checking_middleware(allowed_ts_diff: int, - state_path: str = None): - def eth_client_checking_middleware(make_request, web3): - def middleware(method, params): +def make_client_checking_middleware( + allowed_ts_diff: int, + state_path: str | None = None +) -> Callable[ + [Callable[[RPCEndpoint, Any], RPCResponse], Web3], + Callable[[RPCEndpoint, Any], RPCResponse] +]: + def eth_client_checking_middleware( + make_request: Callable[[RPCEndpoint, Any], RPCResponse], + web3: Web3 + ) -> Callable[[RPCEndpoint, Any], RPCResponse]: + def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: if method in ('eth_block_number', 'eth_getBlockByNumber'): response = make_request(method, params) else: @@ -139,23 +171,23 @@ def middleware(method, params): def init_web3(endpoint: str, provider_timeout: int = DEFAULT_HTTP_TIMEOUT, - middlewares: Iterable = None, - state_path: str = None, ts_diff: int = None): + middlewares: Iterable[Middleware] | None = None, + state_path: str | None = None, ts_diff: int | None = None) -> Web3: if not middlewares: ts_diff = ts_diff or config.ALLOWED_TS_DIFF state_path = state_path or config.LAST_BLOCK_FILE if not ts_diff == config.NO_SYNC_TS_DIFF: sync_middleware = make_client_checking_middleware(ts_diff, state_path) - middewares = ( + middewares = [ http_retry_request_middleware, sync_middleware, attrdict_middleware - ) + ] else: - middewares = ( + middewares = [ http_retry_request_middleware, attrdict_middleware - ) + ] provider = get_provider(endpoint, timeout=provider_timeout) web3 = Web3(provider) @@ -166,20 +198,20 @@ def init_web3(endpoint: str, return web3 -def get_receipt(web3, tx): +def get_receipt(web3: Web3, tx: _Hash32) -> TxReceipt: return web3.eth.get_transaction_receipt(tx) -def get_eth_nonce(web3, address): +def get_eth_nonce(web3: Web3, address: Address | ChecksumAddress | ENS) -> Nonce: return web3.eth.get_transaction_count(address) def wait_for_receipt_by_blocks( - web3, - tx, - blocks_to_wait=DEFAULT_BLOCKS_TO_WAIT, - timeout=MAX_WAITING_TIME -): + web3: Web3, + tx: _Hash32, + blocks_to_wait: int = DEFAULT_BLOCKS_TO_WAIT, + timeout: int = MAX_WAITING_TIME +) -> TxReceipt: blocks_to_wait = blocks_to_wait or DEFAULT_BLOCKS_TO_WAIT timeout = timeout or MAX_WAITING_TIME previous_block = web3.eth.block_number @@ -196,11 +228,11 @@ def wait_for_receipt_by_blocks( current_block = web3.eth.block_number time.sleep(3) raise TransactionNotMinedError( - f'Transaction with hash: {tx} not found in {blocks_to_wait} blocks.' + f'Transaction with hash: {str(tx)} not found in {blocks_to_wait} blocks.' ) -def wait_receipt(web3, tx, retries=30, timeout=5): +def wait_receipt(web3: Web3, tx: _Hash32, retries: int = 30, timeout: int = 5) -> TxReceipt: for _ in range(0, retries): try: receipt = get_receipt(web3, tx) @@ -210,11 +242,11 @@ def wait_receipt(web3, tx, retries=30, timeout=5): return receipt time.sleep(timeout) # pragma: no cover raise TransactionNotMinedError( - f'Transaction with hash: {tx} not mined after {retries} retries.' + f'Transaction with hash: {str(tx)} not mined after {retries} retries.' ) -def check_receipt(receipt, raise_error=True): +def check_receipt(receipt: TxReceipt, raise_error: bool = True) -> bool: if receipt['status'] != 1: # pragma: no cover if raise_error: raise TransactionFailedError( @@ -226,11 +258,11 @@ def check_receipt(receipt, raise_error=True): def wait_for_confirmation_blocks( - web3, - blocks_to_wait, - timeout=MAX_WAITING_TIME, - request_timeout=5 -): + web3: Web3, + blocks_to_wait: int, + timeout: int = MAX_WAITING_TIME, + request_timeout: int = 5 +) -> None: current_block = start_block = web3.eth.block_number logger.info( f'Current block number is {current_block}, ' @@ -243,32 +275,26 @@ def wait_for_confirmation_blocks( time.sleep(request_timeout) -def private_key_to_public(pr): +def private_key_to_public(pr: HexStr) -> HexStr: pr_bytes = Web3.to_bytes(hexstr=pr) - pk = keys.PrivateKey(pr_bytes) - return pk.public_key + prk = keys.PrivateKey(pr_bytes) + pk = prk.public_key + return HexStr(pk.to_hex()) -def public_key_to_address(pk): +def public_key_to_address(pk: HexStr) -> HexStr: hash = Web3.keccak(hexstr=str(pk)) return Web3.to_hex(hash[-20:]) -def private_key_to_address(pr): +def private_key_to_address(pr: HexStr) -> HexStr: pk = private_key_to_public(pr) return public_key_to_address(pk) -def to_checksum_address(address): +def to_checksum_address(address: AnyAddress | str | bytes) -> ChecksumAddress: return Web3.to_checksum_address(address) -def wallet_to_public_key(wallet): - if isinstance(wallet, dict): - return private_key_to_public(wallet['private_key']) - else: - return wallet['public_key'] - - def default_gas_price(web3: Web3) -> int: return web3.eth.gas_price * GAS_PRICE_COEFFICIENT diff --git a/skale/wallets/__init__.py b/skale/wallets/__init__.py index a5d7b685..dcced7a2 100644 --- a/skale/wallets/__init__.py +++ b/skale/wallets/__init__.py @@ -5,3 +5,9 @@ from skale.wallets.redis_wallet import RedisWalletAdapter from skale.wallets.sgx_wallet import SgxWallet from skale.wallets.web3_wallet import Web3Wallet + +__all__ = [ + 'BaseWallet', + 'LedgerWallet', + 'Web3Wallet' +] diff --git a/skale/wallets/common.py b/skale/wallets/common.py index bb97ee8c..9bc8ff38 100644 --- a/skale/wallets/common.py +++ b/skale/wallets/common.py @@ -18,12 +18,18 @@ # along with SKALE.py. If not, see . from abc import ABC, abstractmethod -from typing import Dict, Optional +from typing import Optional + +from eth_account.datastructures import SignedMessage, SignedTransaction +from eth_typing import ChecksumAddress, HexStr +from web3 import Web3 +from web3.types import _Hash32, TxParams, TxReceipt from skale.transactions.exceptions import ChainIdError +from skale.utils.web3_utils import DEFAULT_BLOCKS_TO_WAIT -def ensure_chain_id(tx_dict, web3): +def ensure_chain_id(tx_dict: TxParams, web3: Web3) -> None: if not tx_dict.get('chainId'): tx_dict['chainId'] = web3.eth.chain_id if not tx_dict.get('chainId'): @@ -39,26 +45,26 @@ class MessageNotSignedError(Exception): class BaseWallet(ABC): @abstractmethod - def sign(self, tx): + def sign(self, tx_dict: TxParams) -> SignedTransaction: pass @abstractmethod def sign_and_send( self, - tx_dict: Dict, - multiplier: Optional[int] = None, + tx_dict: TxParams, + multiplier: Optional[float] = None, priority: Optional[int] = None, method: Optional[str] = None - ) -> str: + ) -> HexStr: pass @abstractmethod - def sign_hash(self, unsigned_hash: str) -> str: + def sign_hash(self, unsigned_hash: str) -> SignedMessage: pass @property @abstractmethod - def address(self) -> str: + def address(self) -> ChecksumAddress: pass @property @@ -67,5 +73,5 @@ def public_key(self) -> str: pass @abstractmethod - def wait(self, tx: str, confirmation_blocks: int = None): + def wait(self, tx: _Hash32, confirmation_blocks: int = DEFAULT_BLOCKS_TO_WAIT) -> TxReceipt: pass diff --git a/skale/wallets/ledger_wallet.py b/skale/wallets/ledger_wallet.py index 6d2638e0..096a364b 100644 --- a/skale/wallets/ledger_wallet.py +++ b/skale/wallets/ledger_wallet.py @@ -19,21 +19,31 @@ import logging import struct -from typing import Dict, Optional +from typing import Generator, Tuple, cast +from eth_typing import ChecksumAddress, HexStr from hexbytes import HexBytes -from eth_account.datastructures import SignedTransaction -from eth_account._utils.legacy_transactions import encode_transaction -from eth_account._utils.legacy_transactions import \ - serializable_unsigned_transaction_from_dict as tx_from_dict +from eth_account.datastructures import SignedMessage, SignedTransaction +from eth_account._utils.legacy_transactions import ( + encode_transaction, + serializable_unsigned_transaction_from_dict as tx_from_dict, + Transaction, + UnsignedTransaction +) +from eth_account._utils.typed_transactions import TypedTransaction from eth_utils.crypto import keccak from rlp import encode +from web3 import Web3 +from web3.contract.contract import ContractFunction from web3.exceptions import Web3Exception +from web3.types import _Hash32, TxParams, TxReceipt import skale.config as config from skale.transactions.exceptions import TransactionNotSentError, TransactionNotSignedError from skale.utils.web3_utils import ( + DEFAULT_BLOCKS_TO_WAIT, + MAX_WAITING_TIME, get_eth_nonce, public_key_to_address, to_checksum_address, @@ -48,7 +58,7 @@ class LedgerCommunicationError(Exception): pass -def encode_bip32_path(path): +def encode_bip32_path(path: str) -> bytes: if len(path) == 0: return b'' encoded_chunks = [] @@ -63,27 +73,27 @@ def encode_bip32_path(path): return b''.join(encoded_chunks) -def derivation_path_prefix(bin32_path): +def derivation_path_prefix(bin32_path: str) -> bytes: encoded_path = encode_bip32_path(bin32_path) encoded_path_len_bytes = (len(encoded_path) // 4).to_bytes(1, 'big') return encoded_path_len_bytes + encoded_path -def chunks(sequence, size): +def chunks(sequence: bytes, size: int) -> Generator[bytes, None, None]: return (sequence[pos:pos + size] for pos in range(0, len(sequence), size)) -def get_derivation_path(address_index, legacy) -> str: +def get_derivation_path(address_index: int, legacy: bool) -> str: if legacy: return get_legacy_derivation_path(address_index) return get_live_derivation_path(address_index) -def get_live_derivation_path(address_index) -> str: +def get_live_derivation_path(address_index: int) -> str: return f'44\'/60\'/{address_index}\'/0/0' -def get_legacy_derivation_path(address_index) -> str: +def get_legacy_derivation_path(address_index: int) -> str: return f'44\'/60\'/0\'/{address_index}' @@ -91,7 +101,7 @@ class LedgerWallet(BaseWallet): CHUNK_SIZE = 255 CLA = b'\xe0' - def __init__(self, web3, address_index, legacy=False, debug=False): + def __init__(self, web3: Web3, address_index: int, legacy: bool = False, debug: bool = False): from ledgerblue.comm import getDongle from ledgerblue.commException import CommException @@ -108,25 +118,29 @@ def __init__(self, web3, address_index, legacy=False, debug=False): ) @property - def address(self): + def address(self) -> ChecksumAddress: return self._address @property - def public_key(self): + def public_key(self) -> str: return self._public_key # todo: remove this method after making software wallet as class - def __getitem__(self, key): + def __getitem__(self, key: str) -> str: items = {'address': self.address, 'public_key': self.public_key} return items[key] - def make_payload(self, data=''): - encoded_data = encode(data) + def make_payload(self, data: str = '') -> bytes: + encoded_data = cast(bytes, encode(data)) path_prefix = derivation_path_prefix(self._bip32_path) return path_prefix + encoded_data @classmethod - def parse_sign_result(cls, tx, exchange_result): + def parse_sign_result( + cls, + tx: TypedTransaction | Transaction | UnsignedTransaction, + exchange_result: bytearray | bytes + ) -> SignedTransaction: sign_v = exchange_result[0] sign_r = int((exchange_result[1:1 + 32]).hex(), 16) sign_s = int((exchange_result[1 + 32: 1 + 32 + 32]).hex(), 16) @@ -141,7 +155,7 @@ def parse_sign_result(cls, tx, exchange_result): s=sign_s ) - def exchange_sign_payload_by_chunks(self, payload): + def exchange_sign_payload_by_chunks(self, payload: bytes) -> bytearray: INS = b'\x04' P1_FIRST = b'\x00' P1_SUBSEQUENT = b'\x80' @@ -155,9 +169,9 @@ def exchange_sign_payload_by_chunks(self, payload): ]) exchange_result = self.dongle.exchange(apdu) p1 = P1_SUBSEQUENT - return exchange_result + return cast(bytearray, exchange_result) - def sign(self, tx_dict): + def sign(self, tx_dict: TxParams) -> SignedTransaction: ensure_chain_id(tx_dict, self._web3) if tx_dict.get('nonce') is None: tx_dict['nonce'] = self._web3.eth.get_transaction_count(self.address) @@ -172,34 +186,33 @@ def sign(self, tx_dict): def sign_and_send( self, - tx: Dict, - multiplier: int = config.DEFAULT_GAS_MULTIPLIER, - priority: int = config.DEFAULT_PRIORITY, - method: Optional[str] = None, - meta: Optional[Dict] = None - ) -> str: + tx: TxParams, + multiplier: float | None = config.DEFAULT_GAS_MULTIPLIER, + priority: int | None = config.DEFAULT_PRIORITY, + method: str | None = None + ) -> HexStr: signed_tx = self.sign(tx) try: - return self._web3.eth.send_raw_transaction( + return Web3.to_hex(self._web3.eth.send_raw_transaction( signed_tx.rawTransaction - ).hex() + )) except (ValueError, Web3Exception) as e: raise TransactionNotSentError(e) - def sign_hash(self, unsigned_hash: str): + def sign_hash(self, unsigned_hash: str) -> SignedMessage: raise NotImplementedError( 'sign_hash is not implemented for hardware wallet' ) @classmethod - def parse_derive_result(cls, exchange_result): + def parse_derive_result(cls, exchange_result: bytearray) -> Tuple[ChecksumAddress, str]: pk_len = exchange_result[0] - pk = exchange_result[1: pk_len + 1].hex()[2:] + pk = HexStr(exchange_result[1: pk_len + 1].hex()[2:]) address = public_key_to_address(pk) checksum_address = to_checksum_address(address) return checksum_address, pk - def exchange_derive_payload(self, payload): + def exchange_derive_payload(self, payload: bytes) -> bytearray: INS = b'\x02' P1 = b'\x00' P2 = b'\x00' @@ -208,14 +221,19 @@ def exchange_derive_payload(self, payload): LedgerWallet.CLA, INS, P1, P2, payload_size_in_bytes, payload ]) - return self.dongle.exchange(apdu) + return cast(bytearray, self.dongle.exchange(apdu)) - def get_address_with_public_key(self): + def get_address_with_public_key(self) -> tuple[ChecksumAddress, str]: payload = self.make_payload() exchange_result = self.exchange_derive_payload(payload) return LedgerWallet.parse_derive_result(exchange_result) - def wait(self, tx_hash: str, blocks_to_wait=None, timeout=None): + def wait( + self, + tx_hash: _Hash32, + blocks_to_wait: int = DEFAULT_BLOCKS_TO_WAIT, + timeout: int = MAX_WAITING_TIME + ) -> TxReceipt: return wait_for_receipt_by_blocks( self._web3, tx_hash, @@ -224,8 +242,13 @@ def wait(self, tx_hash: str, blocks_to_wait=None, timeout=None): ) -def hardware_sign_and_send(web3, method, gas_amount, wallet) -> str: - address_from = wallet['address'] +def hardware_sign_and_send( + web3: Web3, + method: ContractFunction, + gas_amount: int, + wallet: LedgerWallet +) -> str: + address_from = wallet.address eth_nonce = get_eth_nonce(web3, address_from) tx_dict = method.build_transaction({ 'gas': gas_amount, @@ -234,6 +257,6 @@ def hardware_sign_and_send(web3, method, gas_amount, wallet) -> str: signed_txn = wallet.sign(tx_dict) tx = web3.eth.send_raw_transaction(signed_txn.rawTransaction).hex() logger.info( - f'{method.__class__.__name__} - transaction_hash: {web3.to_hex(tx)}' + f'{method.__class__.__name__} - transaction_hash: {tx}' ) return tx diff --git a/skale/wallets/redis_wallet.py b/skale/wallets/redis_wallet.py index 2d68b71f..95b34f6c 100644 --- a/skale/wallets/redis_wallet.py +++ b/skale/wallets/redis_wallet.py @@ -23,9 +23,13 @@ import os import time from enum import Enum -from typing import Dict, Optional, Tuple +from typing import Optional, Tuple, TypedDict +from eth_account.datastructures import SignedMessage, SignedTransaction +from eth_typing import ChecksumAddress, HexStr from redis import Redis +from web3 import Web3 +from web3.types import _Hash32, TxParams, TxReceipt import skale.config as config from skale.transactions.exceptions import ( @@ -34,8 +38,9 @@ TransactionNotSentError, TransactionWaitError ) -from skale.utils.web3_utils import get_receipt, MAX_WAITING_TIME +from skale.utils.web3_utils import DEFAULT_BLOCKS_TO_WAIT, get_receipt, MAX_WAITING_TIME from skale.wallets import BaseWallet +from skale.wallets.web3_wallet import Web3Wallet logger = logging.getLogger(__name__) @@ -69,6 +74,15 @@ def __str__(self) -> str: return str.__str__(self) +TxRecord = TypedDict( + 'TxRecord', + { + 'status': TxRecordStatus, + 'tx_hash': HexStr + }, +) + + class RedisWalletAdapter(BaseWallet): ID_SIZE = 16 @@ -76,20 +90,20 @@ def __init__( self, rs: Redis, pool: str, - base_wallet: BaseWallet, + web3_wallet: Web3Wallet, ) -> None: self.rs = rs self.pool = pool - self.wallet = base_wallet + self.wallet = web3_wallet - def sign(self, tx: Dict) -> Dict: + def sign(self, tx: TxParams) -> SignedTransaction: return self.wallet.sign(tx) - def sign_hash(self, unsigned_hash: str) -> str: + def sign_hash(self, unsigned_hash: str) -> SignedMessage: return self.wallet.sign_hash(unsigned_hash) @property - def address(self) -> str: + def address(self) -> ChecksumAddress: return self.wallet.address @property @@ -105,16 +119,15 @@ def _make_raw_id(cls) -> bytes: @classmethod def _make_score(cls, priority: int) -> int: ts = int(time.time()) - return priority * 10 ** len(str(ts)) + ts + return priority * int(10 ** len(str(ts))) + ts @classmethod def _make_record( cls, - tx: Dict, + tx: TxParams, score: int, - multiplier: int = config.DEFAULT_GAS_MULTIPLIER, - method: Optional[str] = None, - meta: Optional[Dict] = None + multiplier: float = config.DEFAULT_GAS_MULTIPLIER, + method: Optional[str] = None ) -> Tuple[bytes, bytes]: tx_id = cls._make_raw_id() params = { @@ -123,7 +136,6 @@ def _make_record( 'multiplier': multiplier, 'tx_hash': None, 'method': method, - 'meta': meta, **tx } # Ensure gas will be restimated in TM @@ -132,20 +144,22 @@ def _make_record( return tx_id, record @classmethod - def _to_raw_id(cls, tx_id: str) -> bytes: - return tx_id.encode('utf-8') + def _to_raw_id(cls, tx_id: _Hash32) -> bytes: + if isinstance(tx_id, str): + return Web3.to_bytes(hexstr=tx_id) + return Web3.to_bytes(tx_id) - def _to_id(cls, raw_id: str) -> str: - return raw_id.decode('utf-8') + @classmethod + def _to_id(cls, raw_id: bytes) -> HexStr: + return Web3.to_hex(raw_id) def sign_and_send( self, - tx: Dict, + tx: TxParams, multiplier: Optional[float] = None, priority: Optional[int] = None, - method: Optional[str] = None, - meta: Optional[Dict] = None - ) -> str: + method: Optional[str] = None + ) -> HexStr: priority = priority or config.DEFAULT_PRIORITY try: logger.info('Sending %s to redis pool, method: %s', tx, method) @@ -153,9 +167,8 @@ def sign_and_send( raw_id, tx_record = self._make_record( tx, score, - multiplier=multiplier, - method=method, - meta=meta + multiplier=multiplier or config.DEFAULT_GAS_MULTIPLIER, + method=method ) pipe = self.rs.pipeline() logger.info('Adding tx %s to the pool', raw_id) @@ -168,19 +181,26 @@ def sign_and_send( logger.exception(f'Sending {tx} with redis wallet errored') raise RedisWalletNotSentError(err) - def get_status(self, tx_id: str) -> str: + def get_status(self, tx_id: _Hash32) -> str: return self.get_record(tx_id)['status'] - def get_record(self, tx_id: str) -> Dict: + def get_record(self, tx_id: _Hash32) -> TxRecord: rid = self._to_raw_id(tx_id) - return json.loads(self.rs.get(rid).decode('utf-8')) + response = self.rs.get(rid) + if isinstance(response, bytes): + parsed_json = json.loads(response.decode('utf-8')) + return TxRecord({ + 'status': parsed_json['status'], + 'tx_hash': parsed_json['tx_hash'] + }) + raise ValueError('Unknown value was returned from get() call', response) def wait( self, - tx_id: str, - blocks_to_wait: Optional[int] = None, + tx_id: _Hash32, + blocks_to_wait: int = DEFAULT_BLOCKS_TO_WAIT, timeout: int = MAX_WAITING_TIME - ) -> Dict: + ) -> TxReceipt: start_ts = time.time() status, result = None, None while status not in [ diff --git a/skale/wallets/sgx_wallet.py b/skale/wallets/sgx_wallet.py index c2053bce..478a1d76 100644 --- a/skale/wallets/sgx_wallet.py +++ b/skale/wallets/sgx_wallet.py @@ -18,15 +18,23 @@ # along with SKALE.py. If not, see . import logging -from typing import Dict, Optional +from typing import Tuple, cast +from eth_account.datastructures import SignedMessage, SignedTransaction +from eth_typing import ChecksumAddress, HexStr from sgx import SgxClient from web3 import Web3 from web3.exceptions import Web3Exception +from web3.types import _Hash32, TxParams, TxReceipt import skale.config as config from skale.transactions.exceptions import TransactionNotSentError, TransactionNotSignedError -from skale.utils.web3_utils import get_eth_nonce, wait_for_receipt_by_blocks +from skale.utils.web3_utils import ( + DEFAULT_BLOCKS_TO_WAIT, + MAX_WAITING_TIME, + get_eth_nonce, + wait_for_receipt_by_blocks +) from skale.wallets.common import BaseWallet, ensure_chain_id, MessageNotSignedError @@ -34,7 +42,13 @@ class SgxWallet(BaseWallet): - def __init__(self, sgx_endpoint, web3, key_name=None, path_to_cert=None): + def __init__( + self, + sgx_endpoint: str, + web3: Web3, + key_name: str | None = None, + path_to_cert: str | None = None + ): self.sgx_client = SgxClient(sgx_endpoint, path_to_cert=path_to_cert) self._web3 = web3 if key_name is None: @@ -43,32 +57,31 @@ def __init__(self, sgx_endpoint, web3, key_name=None, path_to_cert=None): self._key_name = key_name self._address, self._public_key = self._get_account(key_name) - def sign(self, tx_dict): + def sign(self, tx_dict: TxParams) -> SignedTransaction: if tx_dict.get('nonce') is None: tx_dict['nonce'] = get_eth_nonce(self._web3, self._address) ensure_chain_id(tx_dict, self._web3) try: - return self.sgx_client.sign(tx_dict, self.key_name) + return cast(SignedTransaction, self.sgx_client.sign(tx_dict, self.key_name)) except Exception as e: raise TransactionNotSignedError(e) def sign_and_send( self, - tx_dict: Dict, - multiplier: int = config.DEFAULT_GAS_MULTIPLIER, - priority: int = config.DEFAULT_PRIORITY, - method: Optional[str] = None, - meta: Optional[Dict] = None - ) -> str: + tx_dict: TxParams, + multiplier: float | None = config.DEFAULT_GAS_MULTIPLIER, + priority: int | None = config.DEFAULT_PRIORITY, + method: str | None = None + ) -> HexStr: signed_tx = self.sign(tx_dict) try: - return self._web3.eth.send_raw_transaction( + return Web3.to_hex(self._web3.eth.send_raw_transaction( signed_tx.rawTransaction - ).hex() + )) except (ValueError, Web3Exception) as e: raise TransactionNotSentError(e) - def sign_hash(self, unsigned_hash: str): + def sign_hash(self, unsigned_hash: str) -> SignedMessage: if unsigned_hash.startswith('0x'): unsigned_hash = unsigned_hash[2:] @@ -78,35 +91,43 @@ def sign_hash(self, unsigned_hash: str): hash_to_sign = Web3.keccak(hexstr='0x' + normalized_hash.hex()) chain_id = None try: - return self.sgx_client.sign_hash( - hash_to_sign, - self._key_name, - chain_id + return cast( + SignedMessage, + self.sgx_client.sign_hash( + hash_to_sign, + self._key_name, + chain_id + ) ) except Exception as e: raise MessageNotSignedError(e) @property - def address(self): + def address(self) -> ChecksumAddress: return self._address @property - def public_key(self): + def public_key(self) -> str: return self._public_key @property - def key_name(self): + def key_name(self) -> str: return self._key_name - def _generate(self): + def _generate(self) -> Tuple[str, ChecksumAddress, str]: key = self.sgx_client.generate_key() - return key.name, key.address, key.public_key + return key.name, Web3.to_checksum_address(key.address), key.public_key - def _get_account(self, key_name): + def _get_account(self, key_name: str) -> Tuple[ChecksumAddress, str]: account = self.sgx_client.get_account(key_name) - return account.address, account.public_key - - def wait(self, tx_hash: str, blocks_to_wait=None, timeout=None): + return Web3.to_checksum_address(account.address), account.public_key + + def wait( + self, + tx_hash: _Hash32, + blocks_to_wait: int = DEFAULT_BLOCKS_TO_WAIT, + timeout: int = MAX_WAITING_TIME + ) -> TxReceipt: return wait_for_receipt_by_blocks( self._web3, tx_hash, diff --git a/skale/wallets/web3_wallet.py b/skale/wallets/web3_wallet.py index 64b9faf5..9b822e56 100644 --- a/skale/wallets/web3_wallet.py +++ b/skale/wallets/web3_wallet.py @@ -17,10 +17,14 @@ # You should have received a copy of the GNU Affero General Public License # along with SKALE.py. If not, see . -from typing import Dict, Optional -from eth_keys import keys +from typing import cast +from eth_keys.main import lazy_key_api as keys +from eth_keys.datatypes import PublicKey from web3 import Web3 +from web3.types import _Hash32, TxParams, TxReceipt from eth_account import messages +from eth_account.datastructures import SignedMessage, SignedTransaction +from eth_typing import AnyAddress, ChecksumAddress, HexStr from web3.exceptions import Web3Exception import skale.config as config @@ -28,94 +32,109 @@ TransactionNotSignedError, TransactionNotSentError ) -from skale.utils.web3_utils import get_eth_nonce, wait_for_receipt_by_blocks +from skale.utils.web3_utils import ( + DEFAULT_BLOCKS_TO_WAIT, + MAX_WAITING_TIME, + get_eth_nonce, + wait_for_receipt_by_blocks +) from skale.wallets.common import BaseWallet, ensure_chain_id, MessageNotSignedError -def private_key_to_public(pr): +def private_key_to_public(pr: HexStr) -> PublicKey: pr_bytes = Web3.to_bytes(hexstr=pr) pk = keys.PrivateKey(pr_bytes) return pk.public_key -def public_key_to_address(pk): +def public_key_to_address(pk: PublicKey) -> ChecksumAddress: hash = Web3.keccak(hexstr=str(pk)) return to_checksum_address(Web3.to_hex(hash[-20:])) -def private_key_to_address(pr): +def private_key_to_address(pr: HexStr) -> ChecksumAddress: pk = private_key_to_public(pr) return public_key_to_address(pk) -def to_checksum_address(address): +def to_checksum_address(address: AnyAddress | str | bytes) -> ChecksumAddress: return Web3.to_checksum_address(address) -def generate_wallet(web3): - account = web3.eth.account.create() - private_key = account.key.hex() - return Web3Wallet(private_key, web3) - - class Web3Wallet(BaseWallet): - def __init__(self, private_key, web3): + def __init__(self, private_key: HexStr, web3: Web3): self._private_key = private_key self._public_key = private_key_to_public(self._private_key) self._address = public_key_to_address(self._public_key) self._web3 = web3 - def sign(self, tx_dict): + def sign(self, tx_dict: TxParams) -> SignedTransaction: if not tx_dict.get('nonce'): tx_dict['nonce'] = get_eth_nonce(self._web3, self._address) ensure_chain_id(tx_dict, self._web3) try: - return self._web3.eth.account.sign_transaction( - tx_dict, - private_key=self._private_key + return cast( + SignedTransaction, + self._web3.eth.account.sign_transaction( + tx_dict, + private_key=self._private_key + ) ) except (TypeError, ValueError, Web3Exception) as e: raise TransactionNotSignedError(e) - def sign_hash(self, unsigned_hash: str): + def sign_hash(self, unsigned_hash: str) -> SignedMessage: try: unsigned_message = messages.encode_defunct(hexstr=unsigned_hash) - return self._web3.eth.account.sign_message( - unsigned_message, - private_key=self._private_key + return cast( + SignedMessage, + self._web3.eth.account.sign_message( + unsigned_message, + private_key=self._private_key + ) ) except (TypeError, ValueError, Web3Exception) as e: raise MessageNotSignedError(e) def sign_and_send( self, - tx_dict: Dict, - multiplier: int = config.DEFAULT_GAS_MULTIPLIER, - priority: int = config.DEFAULT_PRIORITY, - method: Optional[str] = None, - meta: Optional[Dict] = None - ) -> str: + tx_dict: TxParams, + multiplier: float | None = config.DEFAULT_GAS_MULTIPLIER, + priority: int | None = config.DEFAULT_PRIORITY, + method: str | None = None + ) -> HexStr: signed_tx = self.sign(tx_dict) try: - return self._web3.eth.send_raw_transaction( + return Web3.to_hex(self._web3.eth.send_raw_transaction( signed_tx.rawTransaction - ).hex() + )) except (ValueError, Web3Exception) as e: raise TransactionNotSentError(e) @property - def address(self): + def address(self) -> ChecksumAddress: return self._address @property - def public_key(self): + def public_key(self) -> str: return str(self._public_key) - def wait(self, tx_hash: str, blocks_to_wait=None, timeout=None): + def wait( + self, + tx_hash: _Hash32, + blocks_to_wait: int = DEFAULT_BLOCKS_TO_WAIT, + timeout: int = MAX_WAITING_TIME + ) -> TxReceipt: return wait_for_receipt_by_blocks( self._web3, tx_hash, blocks_to_wait=blocks_to_wait, timeout=timeout ) + + +def generate_wallet(web3: Web3) -> Web3Wallet: + account = web3.eth.account.create() + private_key = account.key.hex() + return Web3Wallet(private_key, web3) diff --git a/tests/allocator/escrow_test.py b/tests/allocator/escrow_test.py index 89eda5c7..1d44db54 100644 --- a/tests/allocator/escrow_test.py +++ b/tests/allocator/escrow_test.py @@ -1,5 +1,6 @@ """ Tests for skale/allocator/escrow.py """ +from skale.types.delegation import DelegationStatus from skale.wallets.web3_wallet import generate_wallet from skale.utils.account_tools import send_eth, check_skale_balance @@ -87,7 +88,7 @@ def test_request_undelegate(skale, skale_allocator): validator_id=D_VALIDATOR_ID ) assert delegations[-1]['id'] == delegation_id - assert delegations[-1]['status'] == 'UNDELEGATION_REQUESTED' + assert delegations[-1]['status'] == DelegationStatus.UNDELEGATION_REQUESTED def test_retrieve(skale, skale_allocator): @@ -169,4 +170,4 @@ def test_cancel_pending_delegation(skale_allocator, skale): validator_id=D_VALIDATOR_ID ) assert delegations[-1]['id'] == delegation_id - assert delegations[-1]['status'] == 'CANCELED' + assert delegations[-1]['status'] == DelegationStatus.CANCELED diff --git a/tests/main_test.py b/tests/main_test.py index f4408e30..45d6a445 100644 --- a/tests/main_test.py +++ b/tests/main_test.py @@ -82,8 +82,8 @@ def test_get_contract_address(skale): def test_get_attr(skale): - random_attr = skale.t123_random_attr - assert random_attr is None + with pytest.raises(ValueError): + skale.t123_random_attr skale_py_nodes_contract = skale.nodes assert issubclass(type(skale_py_nodes_contract), BaseContract) assert isinstance(skale_py_nodes_contract, Nodes) diff --git a/tests/manager/delegation/delegation_controller_test.py b/tests/manager/delegation/delegation_controller_test.py index e8ec4249..3885298c 100644 --- a/tests/manager/delegation/delegation_controller_test.py +++ b/tests/manager/delegation/delegation_controller_test.py @@ -5,6 +5,7 @@ from skale.contracts.manager.delegation.delegation_controller import FIELDS from skale.transactions.exceptions import ContractLogicError from skale.transactions.result import DryRunRevertError +from skale.types.delegation import DelegationStatus from skale.utils.contracts_provision.main import _skip_evm_time from skale.utils.contracts_provision.utils import generate_random_name @@ -34,7 +35,7 @@ def _delegate_and_activate(skale, validator_id=D_VALIDATOR_ID): delegations[-1]['id'], wait_for=True ) - _skip_evm_time(skale.web3, MONTH_IN_SECONDS, mine=False) + _skip_evm_time(skale.web3, MONTH_IN_SECONDS) def _get_number_of_delegations(skale, validator_id=D_VALIDATOR_ID): @@ -84,9 +85,6 @@ def test_delegate(skale, validator): ) assert delegations[-1]['info'] == D_DELEGATION_INFO - delegated_now_after = skale.delegation_controller.get_delegated_to_validator_now( - validator_id - ) delegated_now_after = skale.delegation_controller.get_delegated_to_validator_now( validator_id ) @@ -149,14 +147,14 @@ def test_accept_pending_delegation(skale, validator): validator_id=validator_id ) delegation_id = delegations[-1]['id'] - assert delegations[-1]['status'] == 'PROPOSED' + assert delegations[-1]['status'] == DelegationStatus.PROPOSED assert delegations[-1]['info'] == info skale.delegation_controller.accept_pending_delegation(delegation_id) delegations = skale.delegation_controller.get_all_delegations_by_validator( validator_id=validator_id ) assert delegations[-1]['id'] == delegation_id - assert delegations[-1]['status'] == 'ACCEPTED' + assert delegations[-1]['status'] == DelegationStatus.ACCEPTED assert delegations[-1]['info'] == info @@ -173,7 +171,7 @@ def test_cancel_pending_delegation(skale, validator): validator_id=validator_id ) delegation_id = delegations[-1]['id'] - assert delegations[-1]['status'] == 'PROPOSED' + assert delegations[-1]['status'] == DelegationStatus.PROPOSED skale.delegation_controller.cancel_pending_delegation( delegation_id, wait_for=True @@ -182,7 +180,7 @@ def test_cancel_pending_delegation(skale, validator): validator_id=validator_id ) assert delegations[-1]['id'] == delegation_id - assert delegations[-1]['status'] == 'CANCELED' + assert delegations[-1]['status'] == DelegationStatus.CANCELED def test_request_undelegate(skale, validator): @@ -225,4 +223,4 @@ def test_request_undelegate(skale, validator): validator_id=validator_id ) assert delegations[-1]['id'] == delegation_id - assert delegations[-1]['status'] == 'UNDELEGATION_REQUESTED' + assert delegations[-1]['status'] == DelegationStatus.UNDELEGATION_REQUESTED diff --git a/tests/manager/dkg_test.py b/tests/manager/dkg_test.py index 5bdcc18e..2fe51df1 100644 --- a/tests/manager/dkg_test.py +++ b/tests/manager/dkg_test.py @@ -4,6 +4,7 @@ from hexbytes import HexBytes from skale.contracts.manager.dkg import G2Point, KeyShare +from skale.utils.helper import split_public_key SCHAIN_NAME = 'pointed-asellus-australis' PUBLIC_KEY = '0xfcb3765bdb954ab0672fce731583ad8a94cf05fe63c147f881f8feea18e072d4cad3ec142a65de66a1d50e4fc34a7841c5488ccb55d02cf86013208c17517d64' # noqa @@ -68,10 +69,10 @@ def test_response(skale): group_index = skale.schains.name_to_id(SCHAIN_NAME) share = group_index # not an invariant, only a mock secret_number = 1 - multiplied_share = G2Point(1, 2, 3, 4).tuple - verification_vector = [G2Point(1, 2, 3, 4).tuple for i in range(0, 3)] - verification_vector_mult = [G2Point(1, 2, 3, 4).tuple for i in range(0, 3)] - secret_key_contribution = [KeyShare(PUBLIC_KEY, share).tuple] + multiplied_share = G2Point((1, 2), (3, 4)) + verification_vector = [G2Point((1, 2), (3, 4)) for i in range(0, 3)] + verification_vector_mult = [G2Point((1, 2), (3, 4)) for i in range(0, 3)] + secret_key_contribution = [KeyShare(split_public_key(PUBLIC_KEY), share)] exp = skale.web3.eth.account.sign_transaction( expected_txn, skale.wallet._private_key).rawTransaction diff --git a/tests/manager/manager_test.py b/tests/manager/manager_test.py index 2c7349e9..6ecdc5c7 100644 --- a/tests/manager/manager_test.py +++ b/tests/manager/manager_test.py @@ -69,7 +69,7 @@ def test_create_delete_schain(skale, nodes): schains_ids_after = skale.schains_internal.get_all_schains_ids() schains_names = [ - skale.schains.get(sid)['name'] + skale.schains.get(sid).name for sid in schains_ids_after ] assert name in schains_names @@ -81,7 +81,7 @@ def test_create_delete_schain(skale, nodes): schains_ids_after = skale.schains_internal.get_all_schains_ids() schains_names = [ - skale.schains.get(sid)['name'] + skale.schains.get(sid).name for sid in schains_ids_after ] assert name not in schains_names @@ -100,7 +100,7 @@ def test_delete_schain_by_root(skale, nodes): schains_ids_after = skale.schains_internal.get_all_schains_ids() schains_names = [ - skale.schains.get(sid)['name'] + skale.schains.get(sid).name for sid in schains_ids_after ] assert name not in schains_names @@ -117,7 +117,7 @@ def test_create_delete_default_schain(skale, nodes): schains_ids_after = skale.schains_internal.get_all_schains_ids() schains_names = [ - skale.schains.get(sid)['name'] + skale.schains.get(sid).name for sid in schains_ids_after ] assert name in schains_names @@ -129,7 +129,7 @@ def test_create_delete_default_schain(skale, nodes): schains_ids_after = skale.schains_internal.get_all_schains_ids() schains_names = [ - skale.schains.get(sid)['name'] + skale.schains.get(sid).name for sid in schains_ids_after ] assert name not in schains_names diff --git a/tests/manager/node_rotation_test.py b/tests/manager/node_rotation_test.py index 9aaa6c61..a2aa609f 100644 --- a/tests/manager/node_rotation_test.py +++ b/tests/manager/node_rotation_test.py @@ -12,16 +12,7 @@ def test_get_rotation(skale): - assert skale.node_rotation.get_rotation(DEFAULT_SCHAIN_NAME) == { - 'leaving_node': 0, - 'new_node': 0, - 'freeze_until': 0, - 'rotation_id': 0 - } - - -def test_get_rotation_obj(skale): - assert skale.node_rotation.get_rotation_obj(DEFAULT_SCHAIN_NAME) == Rotation( + assert skale.node_rotation.get_rotation(DEFAULT_SCHAIN_NAME) == Rotation( leaving_node_id=0, new_node_id=0, freeze_until=0, diff --git a/tests/manager/schains_internal_test.py b/tests/manager/schains_internal_test.py index 1abc952d..4dc519e6 100644 --- a/tests/manager/schains_internal_test.py +++ b/tests/manager/schains_internal_test.py @@ -1,5 +1,6 @@ """ SKALE chain internal test """ +from dataclasses import astuple, fields from skale.contracts.manager.schains import FIELDS from tests.constants import ( DEFAULT_SCHAIN_ID, @@ -9,14 +10,14 @@ def test_get_raw(skale): - schain_arr = skale.schains_internal.get_raw(DEFAULT_SCHAIN_ID) - assert len(FIELDS) == len(schain_arr) + 3 # +1 for chainId + options + schain = skale.schains_internal.get_raw(DEFAULT_SCHAIN_ID) + assert len(FIELDS) == len(fields(schain)) + 2 # +2 for chainId + options def test_get_raw_not_exist(skale): not_exist_schain_id = skale.schains.name_to_id('unused_hash') schain_arr = skale.schains_internal.get_raw(not_exist_schain_id) - assert schain_arr == EMPTY_SCHAIN_ARR + assert list(astuple(schain_arr)) == EMPTY_SCHAIN_ARR def test_get_schains_number(skale, schain): @@ -36,7 +37,7 @@ def test_get_schain_id_by_index_for_owner(skale, schain): skale.wallet.address, 0 ) schain = skale.schains.get(schain_id) - assert schain['mainnetOwner'] == skale.wallet.address + assert schain.mainnetOwner == skale.wallet.address def test_get_node_ids_for_schain(skale, schain): diff --git a/tests/manager/schains_test.py b/tests/manager/schains_test.py index 4b2dd78a..d62c2d30 100644 --- a/tests/manager/schains_test.py +++ b/tests/manager/schains_test.py @@ -1,6 +1,8 @@ """ SKALE chain test """ +from dataclasses import fields from hexbytes import HexBytes +from web3 import Web3 from skale.contracts.manager.schains import FIELDS, SchainStructure from skale.dataclasses.schain_options import SchainOptions @@ -15,22 +17,16 @@ def test_get(skale): schain = skale.schains.get(DEFAULT_SCHAIN_ID) - assert list(schain.keys()) == FIELDS - assert [k for k, v in schain.items() if v is None] == [] - - -def test_get_object(skale): - schain = skale.schains.get(DEFAULT_SCHAIN_ID, obj=True) assert isinstance(schain, SchainStructure) assert isinstance(schain.options, SchainOptions) def test_get_by_name(skale): schain = skale.schains.get(DEFAULT_SCHAIN_ID) - schain_name = schain['name'] + schain_name = schain.name schain_by_name = skale.schains.get_by_name(schain_name) - assert list(schain_by_name.keys()) == FIELDS + assert [field.name for field in fields(schain_by_name)] == FIELDS assert schain == schain_by_name @@ -38,7 +34,7 @@ def test_get_schains_for_owner(skale, schain, empty_account): schains = skale.schains.get_schains_for_owner(skale.wallet.address) assert isinstance(schains, list) assert len(schains) > 0 - assert set(schains[-1].keys()) == set(FIELDS) + assert set([field.name for field in fields(schains[-1])]) == set(FIELDS) schains = skale.schains.get_schains_for_owner(empty_account.address) assert schains == [] @@ -57,7 +53,7 @@ def test_get_schains_for_node(skale, schain): test_schain = schains_for_node[0] schain_node_ids = skale.schains_internal.get_node_ids_for_schain( - test_schain['name'] + test_schain.name ) assert node_id in schain_node_ids @@ -65,13 +61,13 @@ def test_get_schains_for_node(skale, schain): def test_name_to_id(skale): schain_id = skale.schains.name_to_id(DEFAULT_SCHAIN_NAME) - assert schain_id == DEFAULT_SCHAIN_ID + assert schain_id == Web3.to_bytes(hexstr=DEFAULT_SCHAIN_ID) def test_get_all_schains_ids(skale, schain): schains_ids = skale.schains_internal.get_all_schains_ids() schain = skale.schains.get(schains_ids[-1]) - assert list(schain.keys()) == FIELDS + assert [field.name for field in fields(schain)] == FIELDS def test_get_schain_price(skale): @@ -93,19 +89,19 @@ def test_add_schain_by_foundation(skale, nodes): ) schains_ids_after = skale.schains_internal.get_all_schains_ids() schains_names = [ - skale.schains.get(sid)['name'] + skale.schains.get(sid).name for sid in schains_ids_after ] assert name in schains_names new_schain = skale.schains.get_by_name(name) - assert new_schain['mainnetOwner'] == skale.wallet.address + assert new_schain.mainnetOwner == skale.wallet.address finally: skale.manager.delete_schain(name, wait_for=True) schains_ids_after = skale.schains_internal.get_all_schains_ids() schains_names = [ - skale.schains.get(sid)['name'] + skale.schains.get(sid).name for sid in schains_ids_after ] assert name not in schains_names @@ -130,7 +126,7 @@ def test_add_schain_by_foundation_with_options(skale, nodes): ), wait_for=True ) - schain = skale.schains.get_by_name(name, obj=True) + schain = skale.schains.get_by_name(name) assert schain.options.multitransaction_mode is True assert schain.options.threshold_encryption is False @@ -158,8 +154,8 @@ def test_add_schain_by_foundation_custom_owner(skale, nodes): ) new_schain = skale.schains.get_by_name(name) - assert new_schain['mainnetOwner'] != skale.wallet.address - assert new_schain['mainnetOwner'] == custom_wallet.address + assert new_schain.mainnetOwner != skale.wallet.address + assert new_schain.mainnetOwner == custom_wallet.address skale.wallet = custom_wallet finally: @@ -169,7 +165,7 @@ def test_add_schain_by_foundation_custom_owner(skale, nodes): schains_ids_after = skale.schains_internal.get_all_schains_ids() schains_names = [ - skale.schains.get(sid)['name'] + skale.schains.get(sid).name for sid in schains_ids_after ] assert name not in schains_names @@ -198,8 +194,8 @@ def test_add_schain_by_foundation_custom_originator(skale, nodes): ) new_schain = skale.schains.get_by_name(name) - assert new_schain['originator'] != skale.wallet.address - assert new_schain['originator'] == custom_originator.address + assert new_schain.originator != skale.wallet.address + assert new_schain.originator == custom_originator.address finally: if name: @@ -208,7 +204,7 @@ def test_add_schain_by_foundation_custom_originator(skale, nodes): schains_ids_after = skale.schains_internal.get_all_schains_ids() schains_names = [ - skale.schains.get(sid)['name'] + skale.schains.get(sid).name for sid in schains_ids_after ] assert name not in schains_names @@ -222,7 +218,7 @@ def test_get_active_schains_for_node(skale, nodes, schain): active_schains = skale.schains.get_active_schains_for_node(node_id) all_schains = skale.schains.get_schains_for_node(node_id) all_active_schains = [ - schain for schain in all_schains if schain['active']] + schain for schain in all_schains if schain.active] for active_schain in all_active_schains: assert active_schain in active_schains finally: diff --git a/tests/rotation_history/rotation_history_test.py b/tests/rotation_history/rotation_history_test.py index a3c57725..4d39610b 100644 --- a/tests/rotation_history/rotation_history_test.py +++ b/tests/rotation_history/rotation_history_test.py @@ -237,7 +237,7 @@ def test_get_new_nodes_list(skale, four_node_schain): rotation_id=1 ) - rotation = skale.node_rotation.get_rotation_obj(name) + rotation = skale.node_rotation.get_rotation(name) node_groups = get_previous_schain_groups( skale=skale, schain_name=name, @@ -251,7 +251,7 @@ def test_get_new_nodes_list(skale, four_node_schain): exiting_node_index = 3 rotate_node(skale, group_index, nodes, skale_instances, exiting_node_index, rotation_id=4) - rotation = skale.node_rotation.get_rotation_obj(name) + rotation = skale.node_rotation.get_rotation(name) node_groups = get_previous_schain_groups( skale=skale, schain_name=name, diff --git a/tests/rotation_history/utils.py b/tests/rotation_history/utils.py index 0091dd13..2791b9b9 100644 --- a/tests/rotation_history/utils.py +++ b/tests/rotation_history/utils.py @@ -4,6 +4,7 @@ import json import logging +from skale.types.dkg import Fp2Point, G2Point, KeyShare from skale.utils.contracts_provision.main import _skip_evm_time from skale.utils.contracts_provision import DEFAULT_DOMAIN_NAME @@ -110,11 +111,21 @@ def init_skale_from_wallet(wallet) -> Skale: def send_broadcasts(nodes, skale_instances, group_index, skip_node_index=None, rotation_id=0): for i, node in enumerate(nodes): if i != skip_node_index: + verification_vector = [ + G2Point(*[ + Fp2Point(*fp2_point) for fp2_point in g2_point + ]) + for g2_point in TEST_DKG_DATA['test_verification_vectors'][i] + ] + secret_key_contribution = [ + KeyShare(tuple(key_share[0]), key_share[1]) + for key_share in TEST_DKG_DATA['test_encrypted_secret_key_contributions'][i] + ] skale_instances[i].dkg.broadcast( group_index, node['node_id'], - TEST_DKG_DATA['test_verification_vectors'][i], - TEST_DKG_DATA['test_encrypted_secret_key_contributions'][i], + verification_vector, + secret_key_contribution, rotation_id ) else: diff --git a/tests/schain_config/generator_test.py b/tests/schain_config/generator_test.py index 4cfefb9b..aaebd186 100644 --- a/tests/schain_config/generator_test.py +++ b/tests/schain_config/generator_test.py @@ -10,11 +10,11 @@ def test_get_nodes_for_schain(skale, schain): fields_with_id.append('id') assert len(schain_nodes) >= MIN_NODES_IN_SCHAIN - assert list(schain_nodes[0].keys()) == fields_with_id + assert set(schain_nodes[0].keys()) == set(fields_with_id) def test_get_schain_nodes_with_schains(skale, schain): schain_name = schain nodes_with_schains = get_schain_nodes_with_schains(skale, schain_name) assert isinstance(nodes_with_schains[0]['schains'], list) - assert isinstance(nodes_with_schains[0]['schains'][0]['mainnetOwner'], str) + assert isinstance(nodes_with_schains[0]['schains'][0].mainnetOwner, str) diff --git a/tests/wallets/redis_adapter_test.py b/tests/wallets/redis_adapter_test.py index 4d032a13..fb954ef6 100644 --- a/tests/wallets/redis_adapter_test.py +++ b/tests/wallets/redis_adapter_test.py @@ -3,6 +3,7 @@ from unittest import mock import pytest from freezegun import freeze_time +from web3 import Web3 from skale.wallets.redis_wallet import ( RedisWalletNotSentError, @@ -51,7 +52,7 @@ def test_make_record(): score = '51623233060' tx_id, r = RedisWalletAdapter._make_record(tx, score, 2, method='createNode') assert tx_id.startswith(b'tx-') and len(tx_id) == 19 - assert r == b'{"status": "PROPOSED", "score": "51623233060", "multiplier": 2, "tx_hash": null, "method": "createNode", "meta": null, "from": "0x1", "to": "0x2", "value": 1, "gasPrice": 1, "gas": null, "nonce": 1, "chainId": 1}' # noqa + assert r == b'{"status": "PROPOSED", "score": "51623233060", "multiplier": 2, "tx_hash": null, "method": "createNode", "from": "0x1", "to": "0x2", "value": 1, "gasPrice": 1, "gas": null, "nonce": 1, "chainId": 1}' # noqa def test_sign_and_send(rdp): @@ -64,8 +65,8 @@ def test_sign_and_send(rdp): 'nonce': 1, 'chainId': 1 } - tx_id = rdp.sign_and_send(tx, multiplier=2, priority=5) - assert tx_id.startswith('tx-') and len(tx_id) == 19 + tx_id = Web3.to_bytes(hexstr=rdp.sign_and_send(tx, multiplier=2, priority=5)) + assert tx_id.startswith(b'tx-') and len(tx_id) == 19 rdp.rs.pipeline = mock.Mock(side_effect=RedisTestError('rtest')) with pytest.raises(RedisWalletNotSentError): diff --git a/tests/wallets/sgx_test.py b/tests/wallets/sgx_test.py index b1fe7132..71dcc8b8 100644 --- a/tests/wallets/sgx_test.py +++ b/tests/wallets/sgx_test.py @@ -44,6 +44,7 @@ def test_sgx_sign(wallet): def test_sgx_sign_and_send_without_nonce(wallet): send_tx_mock = mock.Mock() + send_tx_mock.return_value = HexBytes('') wallet._web3.eth.send_raw_transaction = send_tx_mock wallet._web3.eth.get_transaction_count = mock.Mock(return_value=0) tx_dict = {