diff --git a/src/neo/SmartContract/ApplicationEngine.Contract.cs b/src/neo/SmartContract/ApplicationEngine.Contract.cs index a0c8cb19ca..8b74bed9ea 100644 --- a/src/neo/SmartContract/ApplicationEngine.Contract.cs +++ b/src/neo/SmartContract/ApplicationEngine.Contract.cs @@ -28,7 +28,7 @@ partial class ApplicationEngine /// public static readonly InteropDescriptor System_Contract_CreateStandardAccount = Register("System.Contract.CreateStandardAccount", nameof(CreateStandardAccount), 0_00010000, CallFlags.None, true); - protected internal ContractState CreateContract(byte[] script, byte[] manifest) + protected internal void CreateContract(byte[] script, byte[] manifest) { if (script.Length == 0 || script.Length > MaxContractLength) throw new ArgumentException($"Invalid Script Length: {script.Length}"); @@ -50,7 +50,16 @@ protected internal ContractState CreateContract(byte[] script, byte[] manifest) if (!contract.Manifest.IsValid(hash)) throw new InvalidOperationException($"Invalid Manifest Hash: {hash}"); Snapshot.Contracts.Add(hash, contract); - return contract; + + // We should push it onto the caller's stack. + + Push(Convert(contract)); + + // Execute _deploy + + ContractMethodDescriptor md = contract.Manifest.Abi.GetMethod("_deploy"); + if (md != null) + CallContractInternal(contract, md, new Array(ReferenceCounter) { false }, CallFlags.All, CheckReturnType.EnsureIsEmpty); } protected internal void UpdateContract(byte[] script, byte[] manifest) @@ -90,6 +99,12 @@ protected internal void UpdateContract(byte[] script, byte[] manifest) if (!contract.HasStorage && Snapshot.Storages.Find(BitConverter.GetBytes(contract.Id)).Any()) throw new InvalidOperationException($"Contract Does Not Support Storage But Uses Storage"); } + if (script != null) + { + ContractMethodDescriptor md = contract.Manifest.Abi.GetMethod("_deploy"); + if (md != null) + CallContractInternal(contract, md, new Array(ReferenceCounter) { true }, CallFlags.All, CheckReturnType.EnsureIsEmpty); + } } protected internal void DestroyContract() @@ -121,12 +136,18 @@ private void CallContractInternal(UInt160 contractHash, string method, Array arg ContractState contract = Snapshot.Contracts.TryGet(contractHash); if (contract is null) throw new InvalidOperationException($"Called Contract Does Not Exist: {contractHash}"); + ContractMethodDescriptor md = contract.Manifest.Abi.GetMethod(method); + if (md is null) throw new InvalidOperationException($"Method {method} Does Not Exist In Contract {contractHash}"); ContractManifest currentManifest = Snapshot.Contracts.TryGet(CurrentScriptHash)?.Manifest; - if (currentManifest != null && !currentManifest.CanCall(contract.Manifest, method)) throw new InvalidOperationException($"Cannot Call Method {method} Of Contract {contractHash} From Contract {CurrentScriptHash}"); + CallContractInternal(contract, md, args, flags, CheckReturnType.EnsureNotEmpty); + } + + private void CallContractInternal(ContractState contract, ContractMethodDescriptor method, Array args, CallFlags flags, CheckReturnType checkReturnValue) + { if (invocationCounter.TryGetValue(contract.ScriptHash, out var counter)) { invocationCounter[contract.ScriptHash] = counter + 1; @@ -136,24 +157,22 @@ private void CallContractInternal(UInt160 contractHash, string method, Array arg invocationCounter[contract.ScriptHash] = 1; } - GetInvocationState(CurrentContext).NeedCheckReturnValue = true; + GetInvocationState(CurrentContext).NeedCheckReturnValue = checkReturnValue; ExecutionContextState state = CurrentContext.GetState(); UInt160 callingScriptHash = state.ScriptHash; CallFlags callingFlags = state.CallFlags; - ContractMethodDescriptor md = contract.Manifest.Abi.GetMethod(method); - if (md is null) throw new InvalidOperationException($"Method {method} Does Not Exist In Contract {contractHash}"); - if (args.Count != md.Parameters.Length) throw new InvalidOperationException($"Method {method} Expects {md.Parameters.Length} Arguments But Receives {args.Count} Arguments"); - ExecutionContext context_new = LoadScript(contract.Script, md.Offset); + if (args.Count != method.Parameters.Length) throw new InvalidOperationException($"Method {method.Name} Expects {method.Parameters.Length} Arguments But Receives {args.Count} Arguments"); + ExecutionContext context_new = LoadScript(contract.Script, method.Offset); state = context_new.GetState(); state.CallingScriptHash = callingScriptHash; state.CallFlags = flags & callingFlags; - if (NativeContract.IsNative(contractHash)) + if (NativeContract.IsNative(contract.ScriptHash)) { context_new.EvaluationStack.Push(args); - context_new.EvaluationStack.Push(method); + context_new.EvaluationStack.Push(method.Name); } else { @@ -161,8 +180,8 @@ private void CallContractInternal(UInt160 contractHash, string method, Array arg context_new.EvaluationStack.Push(args[i]); } - md = contract.Manifest.Abi.GetMethod("_initialize"); - if (md != null) LoadContext(context_new.Clone(md.Offset)); + method = contract.Manifest.Abi.GetMethod("_initialize"); + if (method != null) LoadContext(context_new.Clone(method.Offset)); } protected internal bool IsStandardContract(UInt160 hash) diff --git a/src/neo/SmartContract/ApplicationEngine.cs b/src/neo/SmartContract/ApplicationEngine.cs index cd114e9bb7..72a79e8fd0 100644 --- a/src/neo/SmartContract/ApplicationEngine.cs +++ b/src/neo/SmartContract/ApplicationEngine.cs @@ -18,11 +18,18 @@ namespace Neo.SmartContract { public partial class ApplicationEngine : ExecutionEngine { + private enum CheckReturnType : byte + { + None = 0, + EnsureIsEmpty = 1, + EnsureNotEmpty = 2 + } + private class InvocationState { public Type ReturnType; public Delegate Callback; - public bool NeedCheckReturnValue; + public CheckReturnType NeedCheckReturnValue; } /// @@ -97,11 +104,23 @@ protected override void ContextUnloaded(ExecutionContext context) if (!(UncaughtException is null)) return; if (invocationStates.Count == 0) return; if (!invocationStates.Remove(CurrentContext, out InvocationState state)) return; - if (state.NeedCheckReturnValue) - if (context.EvaluationStack.Count == 0) - Push(StackItem.Null); - else if (context.EvaluationStack.Count > 1) - throw new InvalidOperationException(); + switch (state.NeedCheckReturnValue) + { + case CheckReturnType.EnsureIsEmpty: + { + if (context.EvaluationStack.Count != 0) + throw new InvalidOperationException(); + break; + } + case CheckReturnType.EnsureNotEmpty: + { + if (context.EvaluationStack.Count == 0) + Push(StackItem.Null); + else if (context.EvaluationStack.Count > 1) + throw new InvalidOperationException(); + break; + } + } switch (state.Callback) { case null: @@ -142,7 +161,7 @@ protected override void LoadContext(ExecutionContext context) internal void LoadContext(ExecutionContext context, bool checkReturnValue) { if (checkReturnValue) - GetInvocationState(CurrentContext).NeedCheckReturnValue = true; + GetInvocationState(CurrentContext).NeedCheckReturnValue = CheckReturnType.EnsureNotEmpty; LoadContext(context); }