From d75cae70aefaed19053b0b1fa75761da7cae438e Mon Sep 17 00:00:00 2001 From: Erik van den Brink Date: Mon, 18 Jan 2021 14:37:41 +0100 Subject: [PATCH] https://github.com/neo-project/neo/pull/1933 --- neo3/contracts/__init__.py | 6 +- neo3/contracts/applicationengine.py | 16 ++++- neo3/contracts/interop/contract.py | 61 ++++++++++++++----- .../interop/test_contract_interop.py | 4 +- 4 files changed, 65 insertions(+), 22 deletions(-) diff --git a/neo3/contracts/__init__.py b/neo3/contracts/__init__.py index daf6a1fe..5fd7e7c8 100644 --- a/neo3/contracts/__init__.py +++ b/neo3/contracts/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations import hashlib +from enum import IntEnum from .contracttypes import (TriggerType) from .descriptor import (ContractPermissionDescriptor) from .manifest import (ContractGroup, @@ -18,7 +19,7 @@ from .binaryserializer import BinarySerializer from .jsonserializer import (NEOJson, JSONSerializer) from .native import (CallFlags, NativeContract, PolicyContract, NeoToken, GasToken, OracleContract) -from .applicationengine import ApplicationEngine +from .applicationengine import ApplicationEngine, CheckReturnType def syscall_name_to_int(name: str) -> int: @@ -38,4 +39,5 @@ def syscall_name_to_int(name: str) -> int: 'CallFlags', 'PolicyContract', 'ApplicationEngine', - 'syscall_name_to_int'] + 'syscall_name_to_int', + 'CheckReturnType'] diff --git a/neo3/contracts/applicationengine.py b/neo3/contracts/applicationengine.py index 362c7e9b..9ec36624 100644 --- a/neo3/contracts/applicationengine.py +++ b/neo3/contracts/applicationengine.py @@ -9,6 +9,12 @@ from contextlib import suppress +class CheckReturnType(enum.IntEnum): + NONE = 0 + ENSURE_IS_EMPTY = 1 + ENSURE_NOT_EMPTY = 2 + + class ApplicationEngine(vm.ApplicationEngineCpp): _interop_calls: Dict[int, interop.InteropDescriptor] = {} _invocation_states: Dict[vm.ExecutionContext, InvocationState] = {} @@ -27,7 +33,7 @@ class ApplicationEngine(vm.ApplicationEngineCpp): class InvocationState: return_type: type = None # type: ignore callback: Optional[Callable] = None - check_return_value: bool = False + check_return_value: contracts.CheckReturnType = CheckReturnType.NONE def __init__(self, trigger: contracts.TriggerType, @@ -303,7 +309,10 @@ def context_unloaded(self, context: vm.ExecutionContext): state = self._invocation_states.pop(self.current_context) except KeyError: return - if state.check_return_value: + if state.check_return_value == contracts.CheckReturnType.ENSURE_IS_EMPTY: + if len(context.evaluation_stack) != 0: + raise ValueError("Evaluation expected to be empty, but was not") + elif state.check_return_value == contracts.CheckReturnType.ENSURE_NOT_EMPTY: eval_stack_len = len(context.evaluation_stack) if eval_stack_len == 0: self.push(vm.NullStackItem()) @@ -316,7 +325,8 @@ def context_unloaded(self, context: vm.ExecutionContext): def load_context(self, context: vm.ExecutionContext, check_return_value: bool = False): if check_return_value: - self._get_invocation_state(self.current_context).check_return_value = True + self._get_invocation_state(self.current_context).check_return_value = \ + contracts.CheckReturnType.ENSURE_NOT_EMPTY super(ApplicationEngine, self).load_context(context) def load_script_with_callflags(self, script, call_flags: contracts.native.CallFlags, initial_position=0): diff --git a/neo3/contracts/interop/contract.py b/neo3/contracts/interop/contract.py index 820ff602..ea2b64cf 100644 --- a/neo3/contracts/interop/contract.py +++ b/neo3/contracts/interop/contract.py @@ -7,7 +7,7 @@ @register("System.Contract.Create", 0, contracts.native.CallFlags.ALLOW_MODIFIED_STATES, False, [bytes, bytes]) -def contract_create(engine: contracts.ApplicationEngine, script: bytes, manifest: bytes) -> storage.ContractState: +def contract_create(engine: contracts.ApplicationEngine, script: bytes, manifest: bytes) -> None: script_len = len(script) manifest_len = len(manifest) if (script_len == 0 @@ -28,7 +28,17 @@ def contract_create(engine: contracts.ApplicationEngine, script: bytes, manifest raise ValueError("Error: manifest does not match with script") engine.snapshot.contracts.put(contract) - return contract + + engine.push(engine._native_to_stackitem(contract, storage.ContractState)) + method_descriptor = contract.manifest.abi.get_method("_deploy") + if method_descriptor is not None: + contract_call_internal_ex(engine, + contract, + method_descriptor, + vm.ArrayStackItem(engine.reference_counter, vm.BooleanStackItem(False)), + contracts.native.CallFlags.ALL, + contracts.CheckReturnType.ENSURE_IS_EMPTY + ) @register("System.Contract.Update", 0, contracts.native.CallFlags.ALLOW_MODIFIED_STATES, False, [bytes, bytes]) @@ -81,6 +91,17 @@ def contract_update(engine: contracts.ApplicationEngine, script: bytes, manifest and len(list(engine.snapshot.storages.find(contract.script_hash(), key_prefix=b''))) != 0): raise ValueError("Error: New contract does not support storage while old contract has existing storage") + if len(script) != 0: + method_descriptor = contract.manifest.abi.get_method("_deploy") + if method_descriptor is not None: + contract_call_internal_ex(engine, + contract, + method_descriptor, + vm.ArrayStackItem(engine.reference_counter, vm.BooleanStackItem(True)), + contracts.native.CallFlags.ALL, + contracts.CheckReturnType.ENSURE_IS_EMPTY + ) + @register("System.Contract.Destroy", 1000000, contracts.native.CallFlags.ALLOW_MODIFIED_STATES, False) def contract_destroy(engine: contracts.ApplicationEngine) -> None: @@ -109,43 +130,53 @@ def contract_call_internal(engine: contracts.ApplicationEngine, if target_contract is None: raise ValueError("[System.Contract.Call] Can't find target contract") + method_descriptor = target_contract.manifest.abi.get_method(method) + if method_descriptor is None: + raise ValueError(f"[System.Contract.Call] Method '{method}' does not exist on target contract") + current_contract = engine.snapshot.contracts.try_get(engine.current_scripthash, read_only=True) if current_contract and not current_contract.manifest.can_call(target_contract.manifest, method): raise ValueError(f"[System.Contract.Call] Not allowed to call target method '{method}' according to manifest") - counter = engine._invocation_counter.get(target_contract.script_hash(), 0) - engine._invocation_counter.update({target_contract.script_hash(): counter + 1}) + contract_call_internal_ex(engine, target_contract, method_descriptor, args, flags, + contracts.CheckReturnType.ENSURE_NOT_EMPTY) - engine._get_invocation_state(engine.current_context).check_return_value = True + +def contract_call_internal_ex(engine: contracts.ApplicationEngine, + contract: storage.ContractState, + contract_method_descriptor: contracts.ContractMethodDescriptor, + args: vm.ArrayStackItem, + flags: contracts.native.CallFlags, + check_return_value: contracts.CheckReturnType) -> None: + counter = engine._invocation_counter.get(contract.script_hash(), 0) + engine._invocation_counter.update({contract.script_hash(): counter + 1}) + + engine._get_invocation_state(engine.current_context).check_return_value = check_return_value state = engine.current_context calling_flags = state.call_flags - contract_method_descriptor = target_contract.manifest.abi.get_method(method) - if contract_method_descriptor is None: - raise ValueError(f"[System.Contract.Call] requested target method '{method}' does not exist on target contract") - arg_len = len(args) expected_len = len(contract_method_descriptor.parameters) if arg_len != expected_len: raise ValueError( f"[System.Contract.Call] Invalid number of contract arguments. Expected {expected_len} actual {arg_len}") # noqa - context_new = engine.load_script(vm.Script(target_contract.script), contract_method_descriptor.offset) + context_new = engine.load_script(vm.Script(contract.script), contract_method_descriptor.offset) context_new.calling_script = state.script context_new.call_flags = flags & calling_flags - if contracts.NativeContract.is_native(contract_hash): + if contracts.NativeContract.is_native(contract.script_hash()): context_new.evaluation_stack.push(args) - context_new.evaluation_stack.push(vm.ByteStringStackItem(method.encode('utf-8'))) + context_new.evaluation_stack.push(vm.ByteStringStackItem(contract_method_descriptor.name.encode('utf-8'))) else: for item in reversed(args): context_new.evaluation_stack.push(item) context_new.ip = contract_method_descriptor.offset - contract_method_descriptor = target_contract.manifest.abi.get_method("_initialize") - if contract_method_descriptor is not None: - engine.load_cloned_context(contract_method_descriptor.offset) + method_descriptor = contract.manifest.abi.get_method("_initialize") + if method_descriptor is not None: + engine.load_cloned_context(method_descriptor.offset) @register("System.Contract.Call", 1000000, contracts.native.CallFlags.ALLOW_CALL, False, diff --git a/tests/contracts/interop/test_contract_interop.py b/tests/contracts/interop/test_contract_interop.py index de495902..d645498b 100644 --- a/tests/contracts/interop/test_contract_interop.py +++ b/tests/contracts/interop/test_contract_interop.py @@ -437,7 +437,7 @@ def test_contract_call_exceptions(self): engine.snapshot.contracts.put(new_current_contract) with self.assertRaises(ValueError) as context: contract_call_internal(engine, target_contract.script_hash(), "invalid_method", vm.ArrayStackItem(engine.reference_counter), contracts.native.CallFlags) - self.assertEqual("[System.Contract.Call] Not allowed to call target method 'invalid_method' according to manifest", str(context.exception)) + self.assertEqual("[System.Contract.Call] Method 'invalid_method' does not exist on target contract", str(context.exception)) # restore current contract to its original form and try to call a non-existing contract current_contract = storage.ContractState(hello_world_nef.script, hello_world_manifest) @@ -446,7 +446,7 @@ def test_contract_call_exceptions(self): with self.assertRaises(ValueError) as context: contract_call_internal(engine, target_contract.script_hash(), "invalid_method", vm.ArrayStackItem(engine.reference_counter), contracts.native.CallFlags) - self.assertEqual("[System.Contract.Call] requested target method 'invalid_method' does not exist on target contract", str(context.exception)) + self.assertEqual("[System.Contract.Call] Method 'invalid_method' does not exist on target contract", str(context.exception)) # call the target method with invalid number of arguments array = vm.ArrayStackItem(engine.reference_counter)