diff --git a/test/emulator_backend.go b/test/emulator_backend.go index 49823732..f7a86c0c 100644 --- a/test/emulator_backend.go +++ b/test/emulator_backend.go @@ -493,3 +493,13 @@ func (e *EmulatorBackend) replaceImports(code string) string { func (e *EmulatorBackend) StandardLibraryHandler() stdlib.StandardLibraryHandler { return e.stdlibHandler } + +func (e *EmulatorBackend) Reset() { + err := e.blockchain.ReloadBlockchain() + if err != nil { + panic(err) + } + + // Reset the transaction offset. + e.blockOffset = 0 +} diff --git a/test/test_framework_test.go b/test/test_framework_test.go index e6d83803..c6c9e65d 100644 --- a/test/test_framework_test.go +++ b/test/test_framework_test.go @@ -1195,6 +1195,115 @@ func TestSetupAndTearDown(t *testing.T) { }) } +func TestBeforeAndAfterEach(t *testing.T) { + t.Parallel() + + t.Run("beforeEach", func(t *testing.T) { + t.Parallel() + + code := ` + pub(set) var counter = 0 + + pub fun beforeEach() { + counter = counter + 1 + } + + pub fun testFuncOne() { + assert(counter == 1) + } + + pub fun testFuncTwo() { + assert(counter == 2) + } + ` + + runner := NewTestRunner() + results, err := runner.RunTests(code) + require.NoError(t, err) + + require.Len(t, results, 2) + assert.Equal(t, results[0].TestName, "testFuncOne") + require.NoError(t, results[0].Error) + assert.Equal(t, results[1].TestName, "testFuncTwo") + require.NoError(t, results[1].Error) + }) + + t.Run("beforeEach failed", func(t *testing.T) { + t.Parallel() + + code := ` + pub fun beforeEach() { + panic("error occurred") + } + + pub fun testFunc() { + assert(true) + } + ` + + runner := NewTestRunner() + results, err := runner.RunTests(code) + require.Error(t, err) + require.Empty(t, results) + }) + + t.Run("afterEach", func(t *testing.T) { + t.Parallel() + + code := ` + pub(set) var counter = 2 + + pub fun afterEach() { + counter = counter - 1 + } + + pub fun testFuncOne() { + assert(counter == 2) + } + + pub fun testFuncTwo() { + assert(counter == 1) + } + + pub fun tearDown() { + assert(counter == 0) + } + ` + + runner := NewTestRunner() + results, err := runner.RunTests(code) + require.NoError(t, err) + + require.Len(t, results, 2) + assert.Equal(t, results[0].TestName, "testFuncOne") + require.NoError(t, results[0].Error) + assert.Equal(t, results[1].TestName, "testFuncTwo") + require.NoError(t, results[1].Error) + }) + + t.Run("afterEach failed", func(t *testing.T) { + t.Parallel() + + code := ` + pub(set) var tearDownRan = false + + pub fun testFunc() { + assert(!tearDownRan) + } + + pub fun afterEach() { + assert(false) + } + ` + + runner := NewTestRunner() + results, err := runner.RunTests(code) + + require.Error(t, err) + require.Len(t, results, 0) + }) +} + func TestPrettyPrintTestResults(t *testing.T) { t.Parallel() diff --git a/test/test_runner.go b/test/test_runner.go index 3cbbe835..17d10ecc 100644 --- a/test/test_runner.go +++ b/test/test_runner.go @@ -52,6 +52,10 @@ const setupFunctionName = "setup" const tearDownFunctionName = "tearDown" +const beforeEachFunctionName = "beforeEach" + +const afterEachFunctionName = "afterEach" + var testScriptLocation = common.NewScriptLocation(nil, []byte("test")) type Results []Result @@ -128,8 +132,20 @@ func (r *TestRunner) RunTest(script string, funcName string) (result *Result, er return nil, err } + // Run `beforeEach()` before running the test function. + err = r.runBeforeEach(inter) + if err != nil { + return nil, err + } + _, testResult := inter.Invoke(funcName) + // Run `afterEach()` after running the test function. + err = r.runAfterEach(inter) + if err != nil { + return nil, err + } + // Run test `tearDown()` once running all test functions are completed. err = r.runTestTearDown(inter) @@ -167,11 +183,23 @@ func (r *TestRunner) RunTests(script string) (results Results, err error) { continue } - err := r.invokeTestFunction(inter, funcName) + // Run `beforeEach()` before running the test function. + err = r.runBeforeEach(inter) + if err != nil { + return nil, err + } + + testErr := r.invokeTestFunction(inter, funcName) + + // Run `afterEach()` after running the test function. + err = r.runAfterEach(inter) + if err != nil { + return nil, err + } results = append(results, Result{ TestName: funcName, - Error: err, + Error: testErr, }) } @@ -205,6 +233,30 @@ func hasTearDown(inter *interpreter.Interpreter) bool { return inter.Globals.Contains(tearDownFunctionName) } +func (r *TestRunner) runBeforeEach(inter *interpreter.Interpreter) error { + if !hasBeforeEach(inter) { + return nil + } + + return r.invokeTestFunction(inter, beforeEachFunctionName) +} + +func hasBeforeEach(inter *interpreter.Interpreter) bool { + return inter.Globals.Contains(beforeEachFunctionName) +} + +func (r *TestRunner) runAfterEach(inter *interpreter.Interpreter) error { + if !hasAfterEach(inter) { + return nil + } + + return r.invokeTestFunction(inter, afterEachFunctionName) +} + +func hasAfterEach(inter *interpreter.Interpreter) bool { + return inter.Globals.Contains(afterEachFunctionName) +} + func (r *TestRunner) invokeTestFunction(inter *interpreter.Interpreter, funcName string) (err error) { // Individually fail each test-case for any internal error. defer func() {