Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor output in case of error for async callback and empty function name check #840

Merged
74 changes: 74 additions & 0 deletions vmhost/contexts/async_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package contexts

import (
"errors"
"github.com/multiversx/mx-chain-core-go/core"
"math/big"
"testing"

Expand Down Expand Up @@ -496,6 +497,79 @@ func TestAsyncContext_UpdateCurrentCallStatus(t *testing.T) {
require.Equal(t, vmhost.AsyncCallRejected, asyncCall.Status)
}

func TestAsyncContext_OutputInCaseOfErrorInCallback(t *testing.T) {
user := []byte("user")
contractA := []byte("contractA")
contractB := []byte("contractB")

host, _ := initializeVMAndWasmerAsyncContext(t)
host.EnableEpochsHandlerField = &worldmock.EnableEpochsHandlerStub{
IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool {
return flag == vmhost.AsyncV3Flag
},
}

async := makeAsyncContext(t, host, contractA)
host.Storage().SetAddress(contractA)
host.AsyncContext = async

vmInput := &vmcommon.ContractCallInput{
VMInput: vmcommon.VMInput{
CallerAddr: user,
Arguments: [][]byte{{0}},
CallType: vm.DirectCall,
},
RecipientAddr: contractA,
}
host.Runtime().InitStateFromContractCallInput(vmInput)

err := async.RegisterAsyncCall("", &vmhost.AsyncCall{
Destination: contractB,
Data: []byte("function"),
})
require.Nil(t, err)

err = async.Save()
require.Nil(t, err)

asyncCallId := async.GetCallID()
asyncStoragePrefix := host.Storage().GetVmProtectedPrefix(vmhost.AsyncDataPrefix)
asyncCallKey := vmhost.CustomStorageKey(string(asyncStoragePrefix), asyncCallId)

data, _, _, _ := host.Storage().GetStorageUnmetered(asyncCallKey)
require.NotEqual(t, len(data), 0)

vmInput = &vmcommon.ContractCallInput{
VMInput: vmcommon.VMInput{
CallerAddr: contractB,
Arguments: [][]byte{{0}},
CallType: vm.AsynchronousCallBack,
},
RecipientAddr: contractA,
}
host.Runtime().InitStateFromContractCallInput(vmInput)

async.callbackAsyncInitiatorCallID = asyncCallId
async.callType = vmInput.CallType
err = async.LoadParentContext()
require.Nil(t, err)

vmOutput := host.Output().CreateVMOutputInCaseOfError(vmhost.ErrNotEnoughGas)
outputAccount := vmOutput.OutputAccounts[string(contractA)]

require.NotNil(t, outputAccount)

storageUpdates := outputAccount.StorageUpdates
require.Equal(t, len(storageUpdates), 1)

asyncContextDeletionUpdate := storageUpdates[string(asyncCallKey)]
require.NotNil(t, asyncContextDeletionUpdate)
require.Equal(t, len(asyncContextDeletionUpdate.Data), 0)

data, _, _, _ = host.Storage().GetStorageUnmetered(asyncCallKey)
require.Equal(t, len(data), 0)
}

func TestAsyncContext_SendAsyncCallCrossShard(t *testing.T) {
host, world := initializeVMAndWasmerAsyncContext(t)
world.AcctMap.PutAccount(&worldmock.Account{
Expand Down
35 changes: 34 additions & 1 deletion vmhost/contexts/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -562,22 +562,55 @@ func (context *outputContext) DeployCode(input vmhost.CodeDeployInput) {
context.codeUpdates[string(input.ContractAddress)] = empty
}

// createVMOutputInCaseOfErrorOfAsyncCallback appends the deletion of the async context to the output
func (context *outputContext) createVMOutputInCaseOfErrorOfAsyncCallback(returnCode vmcommon.ReturnCode, returnMessage string) *vmcommon.VMOutput {
async := context.host.Async()
metering := context.host.Metering()

callId := async.GetCallbackAsyncInitiatorCallID()

context.outputState = &vmcommon.VMOutput{
GasRemaining: 0,
GasRefund: big.NewInt(0),
ReturnCode: returnCode,
ReturnMessage: returnMessage,
OutputAccounts: make(map[string]*vmcommon.OutputAccount),
}

err := async.DeleteFromCallID(callId)
if err != nil {
logOutput.Trace("failed to delete Async Context", "callId", callId, "err", err)
}

metering.UpdateGasStateOnFailure(context.outputState)

return context.outputState
}

// CreateVMOutputInCaseOfError creates a new vmOutput with the given error set as return message.
func (context *outputContext) CreateVMOutputInCaseOfError(err error) *vmcommon.VMOutput {
runtime := context.host.Runtime()
metering := context.host.Metering()

callType := runtime.GetVMInput().CallType

runtime.AddError(err, runtime.FunctionName())

returnCode := context.resolveReturnCodeFromError(err)
returnMessage := context.resolveReturnMessageFromError(err)

if context.host.EnableEpochsHandler().IsFlagEnabled(vmhost.AsyncV3Flag) && callType == vm.AsynchronousCallBack {
return context.createVMOutputInCaseOfErrorOfAsyncCallback(returnCode, returnMessage)
}

vmOutput := &vmcommon.VMOutput{
GasRemaining: 0,
GasRefund: big.NewInt(0),
ReturnCode: returnCode,
ReturnMessage: returnMessage,
}

context.host.Metering().UpdateGasStateOnFailure(vmOutput)
metering.UpdateGasStateOnFailure(vmOutput)

return vmOutput
}
Expand Down
2 changes: 2 additions & 0 deletions vmhost/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,6 @@ const (
FixOOGReturnCodeFlag core.EnableEpochFlag = "FixOOGReturnCodeFlag"
// DynamicGasCostForDataTrieStorageLoadFlag defines the flag that activates the dynamic gas cost for data trie storage load
DynamicGasCostForDataTrieStorageLoadFlag core.EnableEpochFlag = "DynamicGasCostForDataTrieStorageLoadFlag"
// AsyncV3Flag defines the flag that activates async v3
AsyncV3Flag core.EnableEpochFlag = "AsyncV3Flag"
)
96 changes: 52 additions & 44 deletions vmhost/hostCore/execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,14 @@ func (host *vmHost) checkFinalGasAfterExit() error {
return nil
}

func (host *vmHost) checkValidFunctionName(name string) error {
if name == "" {
return executor.ErrInvalidFunction
}

return nil
}

func (host *vmHost) callInitFunction() error {
return host.callSCFunction(vmhost.InitFunctionName)
}
Expand All @@ -1154,12 +1162,18 @@ func (host *vmHost) callUpgradeFunction() error {
}

func (host *vmHost) callSCFunction(functionName string) error {
err := host.checkValidFunctionName(functionName)
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "checkValidFunctionName")
return err
}

runtime := host.Runtime()
if !runtime.HasFunction(functionName) {
return executor.ErrFuncNotFound
}

err := runtime.CallSCFunction(functionName)
err = runtime.CallSCFunction(functionName)
if err != nil {
err = host.handleBreakpointIfAny(err)
}
Expand Down Expand Up @@ -1236,12 +1250,6 @@ func (host *vmHost) callSCMethodAsynchronousCallBack() error {
metering.UseGas(metering.GasLeft())
}

// TODO matei-p R2 Returning an error here will cause the VMOutput to be
// empty (due to CreateVMOutputInCaseOfError()). But in release 2 of
// Promises, CreateVMOutputInCaseOfError() should still contain storage
// deletions caused by AsyncContext cleanup, even if callbackErr != nil and
// was returned here. The storage deletions MUST be persisted in the data
// trie once R2 goes live.
if !isCallComplete {
return callbackErr
}
Expand All @@ -1263,47 +1271,47 @@ func (host *vmHost) callFunctionAndExecuteAsync() (bool, error) {
runtime := host.Runtime()
async := host.Async()

// TODO refactor this, and apply this condition in other places where a
// function is called
if runtime.FunctionName() != "" {
err := host.verifyAllowedFunctionCall()
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "verifyAllowedFunctionCall")
return false, err
}
err := host.checkValidFunctionName(runtime.FunctionName())
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "checkValidFunctionName")
return false, err
}

functionName, err := runtime.FunctionNameChecked()
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "FunctionNameChecked")
return false, err
}
err = host.verifyAllowedFunctionCall()
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "verifyAllowedFunctionCall")
return false, err
}

err = runtime.CallSCFunction(functionName)
if err != nil {
err = host.handleBreakpointIfAny(err)
log.Trace("breakpoint detected and handled", "err", err)
}
if err == nil {
err = host.checkFinalGasAfterExit()
}
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "sc function")
return true, err
}
functionName, err := runtime.FunctionNameChecked()
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "FunctionNameChecked")
return false, err
}

err = async.Execute()
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "async execution")
return false, err
}
err = runtime.CallSCFunction(functionName)
if err != nil {
err = host.handleBreakpointIfAny(err)
log.Trace("breakpoint detected and handled", "err", err)
}
if err == nil {
err = host.checkFinalGasAfterExit()
}
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "sc function")
return true, err
}

if !async.IsComplete() || async.HasLegacyGroup() {
async.SetResults(host.Output().GetVMOutput())
err = async.Save()
return false, err
}
} else {
return false, executor.ErrInvalidFunction
err = async.Execute()
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "async execution")
return false, err
}

if !async.IsComplete() || async.HasLegacyGroup() {
async.SetResults(host.Output().GetVMOutput())
err = async.Save()
return false, err
}

return true, nil
Expand Down