diff --git a/runtime/stdlib/test_emulatorbackend.go b/runtime/stdlib/test_emulatorbackend.go index da0b8aacd4..960421aa0e 100644 --- a/runtime/stdlib/test_emulatorbackend.go +++ b/runtime/stdlib/test_emulatorbackend.go @@ -639,15 +639,21 @@ func (t *testEmulatorBackendType) newEventsFunction( return interpreter.NewUnmeteredHostFunctionValue( t.eventsFunctionType, func(invocation interpreter.Invocation) interpreter.Value { - value, ok := invocation.Arguments[0].(interpreter.OptionalValue) - if !ok { - panic(errors.NewUnreachableError()) - } - var eventType interpreter.StaticType = nil - _, isNilValue := value.(interpreter.NilValue) - if !isNilValue { - eventType = value.StaticType(invocation.Interpreter) + + switch value := invocation.Arguments[0].(type) { + case interpreter.NilValue: + // Do nothing + case *interpreter.SomeValue: + innerValue := value.InnerValue(invocation.Interpreter, invocation.LocationRange) + typeValue, ok := innerValue.(interpreter.TypeValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + eventType = typeValue.Type + default: + panic(errors.NewUnreachableError()) } return testFramework.Events(invocation.Interpreter, eventType) diff --git a/runtime/stdlib/test_test.go b/runtime/stdlib/test_test.go index 0397b77ea5..7782cfe2e5 100644 --- a/runtime/stdlib/test_test.go +++ b/runtime/stdlib/test_test.go @@ -37,6 +37,14 @@ import ( ) func newTestContractInterpreter(t *testing.T, code string) (*interpreter.Interpreter, error) { + return newTestContractInterpreterWithTestFramework(t, code, nil) +} + +func newTestContractInterpreterWithTestFramework( + t *testing.T, + code string, + testFramework TestFramework, +) (*interpreter.Interpreter, error) { program, err := parser.ParseProgram( nil, []byte(code), @@ -109,7 +117,7 @@ func newTestContractInterpreter(t *testing.T, code string) (*interpreter.Interpr return nil }, - ContractValueHandler: NewTestInterpreterContractValueHandler(nil), + ContractValueHandler: NewTestInterpreterContractValueHandler(testFramework), UUIDHandler: func() (uint64, error) { uuid++ return uuid, nil @@ -1636,3 +1644,235 @@ func TestTestExpectFailure(t *testing.T) { assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) }) } + +func TestBlockchain(t *testing.T) { + + t.Parallel() + + t.Run("all events, empty", func(t *testing.T) { + t.Parallel() + + script := ` + import Test + + pub fun test(): [AnyStruct] { + var blockchain = Test.newEmulatorBlockchain() + return blockchain.events() + } + ` + + eventsInvoked := false + + testFramework := &mockedTestFramework{ + events: func(inter *interpreter.Interpreter, eventType interpreter.StaticType) interpreter.Value { + eventsInvoked = true + assert.Nil(t, eventType) + return interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + interpreter.NewVariableSizedStaticType(inter, interpreter.PrimitiveStaticTypeAnyStruct), + common.Address{}, + ) + }, + } + + inter, err := newTestContractInterpreterWithTestFramework(t, script, testFramework) + require.NoError(t, err) + + _, err = inter.Invoke("test") + require.NoError(t, err) + + assert.True(t, eventsInvoked) + }) + + t.Run("typed events, empty", func(t *testing.T) { + t.Parallel() + + script := ` + import Test + + pub fun test(): [AnyStruct] { + var blockchain = Test.newEmulatorBlockchain() + + // 'Foo' is not an event-type. + // But we just need to test the API, so it doesn't really matter. + var typ = Type() + + return blockchain.eventsOfType(typ) + } + + pub struct Foo {} + ` + + eventsInvoked := false + + testFramework := &mockedTestFramework{ + events: func(inter *interpreter.Interpreter, eventType interpreter.StaticType) interpreter.Value { + eventsInvoked = true + assert.NotNil(t, eventType) + + require.IsType(t, interpreter.CompositeStaticType{}, eventType) + compositeType := eventType.(interpreter.CompositeStaticType) + assert.Equal(t, "Foo", compositeType.QualifiedIdentifier) + + return interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + interpreter.NewVariableSizedStaticType(inter, interpreter.PrimitiveStaticTypeAnyStruct), + common.Address{}, + ) + }, + } + + inter, err := newTestContractInterpreterWithTestFramework(t, script, testFramework) + require.NoError(t, err) + + _, err = inter.Invoke("test") + require.NoError(t, err) + + assert.True(t, eventsInvoked) + }) + + // TODO: Add more tests for the remaining functions. +} + +type mockedTestFramework struct { + runScript func(inter *interpreter.Interpreter, code string, arguments []interpreter.Value) + createAccount func() (*Account, error) + addTransaction func(inter *interpreter.Interpreter, code string, authorizers []common.Address, signers []*Account, arguments []interpreter.Value) error + executeTransaction func() *TransactionResult + commitBlock func() error + deployContract func(inter *interpreter.Interpreter, name string, code string, account *Account, arguments []interpreter.Value) error + readFile func(s string) (string, error) + useConfiguration func(configuration *Configuration) + stdlibHandler func() StandardLibraryHandler + logs func() []string + serviceAccount func() (*Account, error) + events func(inter *interpreter.Interpreter, eventType interpreter.StaticType) interpreter.Value + reset func() +} + +var _ TestFramework = &mockedTestFramework{} + +func (m mockedTestFramework) RunScript( + inter *interpreter.Interpreter, + code string, + arguments []interpreter.Value, +) *ScriptResult { + if m.runScript == nil { + panic("'RunScript' is not implemented") + } + + return m.RunScript(inter, code, arguments) +} + +func (m mockedTestFramework) CreateAccount() (*Account, error) { + if m.createAccount == nil { + panic("'CreateAccount' is not implemented") + } + + return m.createAccount() +} + +func (m mockedTestFramework) AddTransaction( + inter *interpreter.Interpreter, + code string, + authorizers []common.Address, + signers []*Account, + arguments []interpreter.Value, +) error { + if m.addTransaction == nil { + panic("'AddTransaction' is not implemented") + } + + return m.addTransaction(inter, code, authorizers, signers, arguments) +} + +func (m mockedTestFramework) ExecuteNextTransaction() *TransactionResult { + if m.executeTransaction == nil { + panic("'ExecuteNextTransaction' is not implemented") + } + + return m.executeTransaction() +} + +func (m mockedTestFramework) CommitBlock() error { + if m.commitBlock == nil { + panic("'CommitBlock' is not implemented") + } + + return m.commitBlock() +} + +func (m mockedTestFramework) DeployContract( + inter *interpreter.Interpreter, + name string, + code string, + account *Account, + arguments []interpreter.Value, +) error { + if m.deployContract == nil { + panic("'DeployContract' is not implemented") + } + + return m.deployContract(inter, name, code, account, arguments) +} + +func (m mockedTestFramework) ReadFile(fileName string) (string, error) { + if m.readFile == nil { + panic("'ReadFile' is not implemented") + } + + return m.readFile(fileName) +} + +func (m mockedTestFramework) UseConfiguration(configuration *Configuration) { + if m.useConfiguration == nil { + panic("'UseConfiguration' is not implemented") + } + + m.useConfiguration(configuration) +} + +func (m mockedTestFramework) StandardLibraryHandler() StandardLibraryHandler { + if m.stdlibHandler == nil { + panic("'StandardLibraryHandler' is not implemented") + } + + return m.stdlibHandler() +} + +func (m mockedTestFramework) Logs() []string { + if m.logs == nil { + panic("'Logs' is not implemented") + } + + return m.logs() +} + +func (m mockedTestFramework) ServiceAccount() (*Account, error) { + if m.serviceAccount == nil { + panic("'ServiceAccount' is not implemented") + } + + return m.serviceAccount() +} + +func (m mockedTestFramework) Events( + inter *interpreter.Interpreter, + eventType interpreter.StaticType, +) interpreter.Value { + if m.events == nil { + panic("'Events' is not implemented") + } + + return m.events(inter, eventType) +} + +func (m mockedTestFramework) Reset() { + if m.reset == nil { + panic("'Reset' is not implemented") + } + + m.reset() +}