Skip to content

Commit

Permalink
https://github.com/neo-project/neo/pull/1933
Browse files Browse the repository at this point in the history
  • Loading branch information
ixje committed Jan 18, 2021
1 parent 61a87b3 commit d75cae7
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 22 deletions.
6 changes: 4 additions & 2 deletions neo3/contracts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
import hashlib
from enum import IntEnum
from .contracttypes import (TriggerType)
from .descriptor import (ContractPermissionDescriptor)
from .manifest import (ContractGroup,
Expand All @@ -18,7 +19,7 @@
from .binaryserializer import BinarySerializer
from .jsonserializer import (NEOJson, JSONSerializer)
from .native import (CallFlags, NativeContract, PolicyContract, NeoToken, GasToken, OracleContract)
from .applicationengine import ApplicationEngine
from .applicationengine import ApplicationEngine, CheckReturnType


def syscall_name_to_int(name: str) -> int:
Expand All @@ -38,4 +39,5 @@ def syscall_name_to_int(name: str) -> int:
'CallFlags',
'PolicyContract',
'ApplicationEngine',
'syscall_name_to_int']
'syscall_name_to_int',
'CheckReturnType']
16 changes: 13 additions & 3 deletions neo3/contracts/applicationengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
from contextlib import suppress


class CheckReturnType(enum.IntEnum):
NONE = 0
ENSURE_IS_EMPTY = 1
ENSURE_NOT_EMPTY = 2


class ApplicationEngine(vm.ApplicationEngineCpp):
_interop_calls: Dict[int, interop.InteropDescriptor] = {}
_invocation_states: Dict[vm.ExecutionContext, InvocationState] = {}
Expand All @@ -27,7 +33,7 @@ class ApplicationEngine(vm.ApplicationEngineCpp):
class InvocationState:
return_type: type = None # type: ignore
callback: Optional[Callable] = None
check_return_value: bool = False
check_return_value: contracts.CheckReturnType = CheckReturnType.NONE

def __init__(self,
trigger: contracts.TriggerType,
Expand Down Expand Up @@ -303,7 +309,10 @@ def context_unloaded(self, context: vm.ExecutionContext):
state = self._invocation_states.pop(self.current_context)
except KeyError:
return
if state.check_return_value:
if state.check_return_value == contracts.CheckReturnType.ENSURE_IS_EMPTY:
if len(context.evaluation_stack) != 0:
raise ValueError("Evaluation expected to be empty, but was not")
elif state.check_return_value == contracts.CheckReturnType.ENSURE_NOT_EMPTY:
eval_stack_len = len(context.evaluation_stack)
if eval_stack_len == 0:
self.push(vm.NullStackItem())
Expand All @@ -316,7 +325,8 @@ def context_unloaded(self, context: vm.ExecutionContext):

def load_context(self, context: vm.ExecutionContext, check_return_value: bool = False):
if check_return_value:
self._get_invocation_state(self.current_context).check_return_value = True
self._get_invocation_state(self.current_context).check_return_value = \
contracts.CheckReturnType.ENSURE_NOT_EMPTY
super(ApplicationEngine, self).load_context(context)

def load_script_with_callflags(self, script, call_flags: contracts.native.CallFlags, initial_position=0):
Expand Down
61 changes: 46 additions & 15 deletions neo3/contracts/interop/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


@register("System.Contract.Create", 0, contracts.native.CallFlags.ALLOW_MODIFIED_STATES, False, [bytes, bytes])
def contract_create(engine: contracts.ApplicationEngine, script: bytes, manifest: bytes) -> storage.ContractState:
def contract_create(engine: contracts.ApplicationEngine, script: bytes, manifest: bytes) -> None:
script_len = len(script)
manifest_len = len(manifest)
if (script_len == 0
Expand All @@ -28,7 +28,17 @@ def contract_create(engine: contracts.ApplicationEngine, script: bytes, manifest
raise ValueError("Error: manifest does not match with script")

engine.snapshot.contracts.put(contract)
return contract

engine.push(engine._native_to_stackitem(contract, storage.ContractState))
method_descriptor = contract.manifest.abi.get_method("_deploy")
if method_descriptor is not None:
contract_call_internal_ex(engine,
contract,
method_descriptor,
vm.ArrayStackItem(engine.reference_counter, vm.BooleanStackItem(False)),
contracts.native.CallFlags.ALL,
contracts.CheckReturnType.ENSURE_IS_EMPTY
)


@register("System.Contract.Update", 0, contracts.native.CallFlags.ALLOW_MODIFIED_STATES, False, [bytes, bytes])
Expand Down Expand Up @@ -81,6 +91,17 @@ def contract_update(engine: contracts.ApplicationEngine, script: bytes, manifest
and len(list(engine.snapshot.storages.find(contract.script_hash(), key_prefix=b''))) != 0):
raise ValueError("Error: New contract does not support storage while old contract has existing storage")

if len(script) != 0:
method_descriptor = contract.manifest.abi.get_method("_deploy")
if method_descriptor is not None:
contract_call_internal_ex(engine,
contract,
method_descriptor,
vm.ArrayStackItem(engine.reference_counter, vm.BooleanStackItem(True)),
contracts.native.CallFlags.ALL,
contracts.CheckReturnType.ENSURE_IS_EMPTY
)


@register("System.Contract.Destroy", 1000000, contracts.native.CallFlags.ALLOW_MODIFIED_STATES, False)
def contract_destroy(engine: contracts.ApplicationEngine) -> None:
Expand Down Expand Up @@ -109,43 +130,53 @@ def contract_call_internal(engine: contracts.ApplicationEngine,
if target_contract is None:
raise ValueError("[System.Contract.Call] Can't find target contract")

method_descriptor = target_contract.manifest.abi.get_method(method)
if method_descriptor is None:
raise ValueError(f"[System.Contract.Call] Method '{method}' does not exist on target contract")

current_contract = engine.snapshot.contracts.try_get(engine.current_scripthash, read_only=True)
if current_contract and not current_contract.manifest.can_call(target_contract.manifest, method):
raise ValueError(f"[System.Contract.Call] Not allowed to call target method '{method}' according to manifest")

counter = engine._invocation_counter.get(target_contract.script_hash(), 0)
engine._invocation_counter.update({target_contract.script_hash(): counter + 1})
contract_call_internal_ex(engine, target_contract, method_descriptor, args, flags,
contracts.CheckReturnType.ENSURE_NOT_EMPTY)

engine._get_invocation_state(engine.current_context).check_return_value = True

def contract_call_internal_ex(engine: contracts.ApplicationEngine,
contract: storage.ContractState,
contract_method_descriptor: contracts.ContractMethodDescriptor,
args: vm.ArrayStackItem,
flags: contracts.native.CallFlags,
check_return_value: contracts.CheckReturnType) -> None:
counter = engine._invocation_counter.get(contract.script_hash(), 0)
engine._invocation_counter.update({contract.script_hash(): counter + 1})

engine._get_invocation_state(engine.current_context).check_return_value = check_return_value

state = engine.current_context
calling_flags = state.call_flags

contract_method_descriptor = target_contract.manifest.abi.get_method(method)
if contract_method_descriptor is None:
raise ValueError(f"[System.Contract.Call] requested target method '{method}' does not exist on target contract")

arg_len = len(args)
expected_len = len(contract_method_descriptor.parameters)
if arg_len != expected_len:
raise ValueError(
f"[System.Contract.Call] Invalid number of contract arguments. Expected {expected_len} actual {arg_len}") # noqa

context_new = engine.load_script(vm.Script(target_contract.script), contract_method_descriptor.offset)
context_new = engine.load_script(vm.Script(contract.script), contract_method_descriptor.offset)
context_new.calling_script = state.script
context_new.call_flags = flags & calling_flags

if contracts.NativeContract.is_native(contract_hash):
if contracts.NativeContract.is_native(contract.script_hash()):
context_new.evaluation_stack.push(args)
context_new.evaluation_stack.push(vm.ByteStringStackItem(method.encode('utf-8')))
context_new.evaluation_stack.push(vm.ByteStringStackItem(contract_method_descriptor.name.encode('utf-8')))
else:
for item in reversed(args):
context_new.evaluation_stack.push(item)
context_new.ip = contract_method_descriptor.offset

contract_method_descriptor = target_contract.manifest.abi.get_method("_initialize")
if contract_method_descriptor is not None:
engine.load_cloned_context(contract_method_descriptor.offset)
method_descriptor = contract.manifest.abi.get_method("_initialize")
if method_descriptor is not None:
engine.load_cloned_context(method_descriptor.offset)


@register("System.Contract.Call", 1000000, contracts.native.CallFlags.ALLOW_CALL, False,
Expand Down
4 changes: 2 additions & 2 deletions tests/contracts/interop/test_contract_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def test_contract_call_exceptions(self):
engine.snapshot.contracts.put(new_current_contract)
with self.assertRaises(ValueError) as context:
contract_call_internal(engine, target_contract.script_hash(), "invalid_method", vm.ArrayStackItem(engine.reference_counter), contracts.native.CallFlags)
self.assertEqual("[System.Contract.Call] Not allowed to call target method 'invalid_method' according to manifest", str(context.exception))
self.assertEqual("[System.Contract.Call] Method 'invalid_method' does not exist on target contract", str(context.exception))

# restore current contract to its original form and try to call a non-existing contract
current_contract = storage.ContractState(hello_world_nef.script, hello_world_manifest)
Expand All @@ -446,7 +446,7 @@ def test_contract_call_exceptions(self):

with self.assertRaises(ValueError) as context:
contract_call_internal(engine, target_contract.script_hash(), "invalid_method", vm.ArrayStackItem(engine.reference_counter), contracts.native.CallFlags)
self.assertEqual("[System.Contract.Call] requested target method 'invalid_method' does not exist on target contract", str(context.exception))
self.assertEqual("[System.Contract.Call] Method 'invalid_method' does not exist on target contract", str(context.exception))

# call the target method with invalid number of arguments
array = vm.ArrayStackItem(engine.reference_counter)
Expand Down

0 comments on commit d75cae7

Please sign in to comment.