Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cmd: add watch mode to opa test command #5812

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 99 additions & 15 deletions cmd/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ package cmd
import (
"context"
"fmt"
"io"
"os"
"os/signal"
"strings"
"syscall"
"time"

"github.com/spf13/cobra"
Expand All @@ -16,7 +20,11 @@ import (
"github.com/open-policy-agent/opa/bundle"
"github.com/open-policy-agent/opa/compile"
"github.com/open-policy-agent/opa/cover"
"github.com/open-policy-agent/opa/filewatcher"
"github.com/open-policy-agent/opa/internal/runtime"
initload "github.com/open-policy-agent/opa/internal/runtime/init"
"github.com/open-policy-agent/opa/loader"
"github.com/open-policy-agent/opa/logging"
"github.com/open-policy-agent/opa/storage"
"github.com/open-policy-agent/opa/storage/inmem"
"github.com/open-policy-agent/opa/tester"
Expand Down Expand Up @@ -47,6 +55,9 @@ type testCommandParams struct {
target *util.EnumFlag
skipExitZero bool
capabilities *capabilitiesFlag
watch bool
output io.Writer
killChan chan os.Signal
}

func newTestCommandParams() *testCommandParams {
Expand All @@ -55,6 +66,8 @@ func newTestCommandParams() *testCommandParams {
explain: newExplainFlag([]string{explainModeFails, explainModeFull, explainModeNotes, explainModeDebug}),
target: util.NewEnumFlag(compile.TargetRego, []string{compile.TargetRego, compile.TargetWasm}),
capabilities: newcapabilitiesFlag(),
output: os.Stdout,
killChan: make(chan os.Signal, 1),
}
}

Expand Down Expand Up @@ -148,10 +161,66 @@ The optional "gobench" output format conforms to the Go Benchmark Data Format.
},
}

func opaTest(args []string) int {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
func newOnReload(c chan int) filewatcher.OnReload {

onReload := func(ctx context.Context, txn storage.Transaction, d time.Duration, s storage.Store, l *initload.LoadPathsResult, err error) {
notify := func() {
c <- 1
}
defer notify()

if err != nil {
fmt.Printf("error reloading files: %v\n", err)
return
}

// FIXME: We don't detect when data files are removed.
if len(l.Files.Documents) > 0 {
if err := s.Write(ctx, txn, storage.AddOp, storage.Path{}, l.Files.Documents); err != nil {
fmt.Printf("storage error: %v\n", err)
return
}
}

modules := map[string]*ast.Module{}
for id, module := range l.Files.Modules {
modules[id] = module.Parsed
}

compileAndRunTests(ctx, txn, s, modules, l.Bundles)
}

return onReload
}

func watchTests(ctx context.Context, paths []string, filter loader.Filter, bundleMode bool, store storage.Store) int {
reloadChan := make(chan int)
onReload := newOnReload(reloadChan)

signal.Notify(testParams.killChan, syscall.SIGINT, syscall.SIGTERM)

logger := logging.New()

w := filewatcher.NewFileWatcher(paths, filter, bundleMode, store, onReload, logger)
err := w.Start(ctx)
if err != nil {
fmt.Fprintln(os.Stderr, "error", err)
return 1
}

for {
fmt.Fprintln(testParams.output, strings.Repeat("*", 80))
fmt.Fprintln(testParams.output, "Watching for changes ...")
select {
case <-testParams.killChan:
return 0
case <-reloadChan:
break
}
}
}

func opaTest(args []string) int {
if testParams.outputFormat.String() == benchmarkGoBenchOutput && !testParams.benchmark {
fmt.Fprintf(os.Stderr, "cannot use output format %s without running benchmarks (--bench)\n", benchmarkGoBenchOutput)
return 0
Expand All @@ -166,30 +235,44 @@ func opaTest(args []string) int {
Ignore: testParams.ignore,
}

var modules map[string]*ast.Module
var bundles map[string]*bundle.Bundle
var store storage.Store
var err error

if testParams.bundleMode {
bundles, err = tester.LoadBundles(args, filter.Apply)
store = inmem.NewWithOpts(inmem.OptRoundTripOnWrite(false))
} else {
modules, store, err = tester.Load(args, filter.Apply)
}

result, err := initload.LoadPaths(args, filter.Apply, testParams.bundleMode, nil, true, false, nil)
if err != nil {
fmt.Fprintln(os.Stderr, err)
return 1
}

store = inmem.NewFromObjectWithOpts(result.Files.Documents, inmem.OptRoundTripOnWrite(false))

modules := map[string]*ast.Module{}
for _, m := range result.Files.Modules {
modules[m.Name] = m.Parsed
}

bundles := result.Bundles

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

txn, err := store.NewTransaction(ctx, storage.WriteParams)
if err != nil {
fmt.Fprintln(os.Stderr, err)
return 1
}

defer store.Abort(ctx, txn)
if testParams.watch {
compileAndRunTests(ctx, txn, store, modules, bundles)
store.Commit(ctx, txn)
return watchTests(ctx, args, filter.Apply, testParams.bundleMode, store)
} else {
defer store.Abort(ctx, txn)
return compileAndRunTests(ctx, txn, store, modules, bundles)
}
}

func compileAndRunTests(ctx context.Context, txn storage.Transaction, store storage.Store, modules map[string]*ast.Module, bundles map[string]*bundle.Bundle) int {

var capabilities *ast.Capabilities
// if capabilities are not provided as a cmd flag,
Expand Down Expand Up @@ -266,7 +349,7 @@ func opaTest(args []string) int {
default:
reporter = tester.PrettyReporter{
Verbose: testParams.verbose,
Output: os.Stdout,
Output: testParams.output,
BenchmarkResults: testParams.benchmark,
BenchMarkShowAllocations: testParams.benchMem,
BenchMarkGoBenchFormat: goBench,
Expand All @@ -276,7 +359,7 @@ func opaTest(args []string) int {
reporter = tester.JSONCoverageReporter{
Cover: cov,
Modules: modules,
Output: os.Stdout,
Output: testParams.output,
Threshold: testParams.threshold,
}
}
Expand Down Expand Up @@ -386,5 +469,6 @@ func init() {
setExplainFlag(testCommand.Flags(), testParams.explain)
addTargetFlag(testCommand.Flags(), testParams.target)
addCapabilitiesFlag(testCommand.Flags(), testParams.capabilities)
testCommand.Flags().BoolVarP(&testParams.watch, "watch", "w", false, "watch for file changes and re-run tests")
RootCommand.AddCommand(testCommand)
}
139 changes: 139 additions & 0 deletions cmd/test_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@ package cmd
import (
"bytes"
"context"
"os"
"path"
"path/filepath"
"regexp"
"strings"
"syscall"
"testing"
"time"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
Expand Down Expand Up @@ -272,3 +278,136 @@ test_p {
t.Fatalf("didn't get expected %s error when inlined schema is present; got: %v", ast.TypeErr, err)
}
}

func TestWatchMode(t *testing.T) {

var buff bytes.Buffer
testParams.output = &buff

policyPath := "test/policy.rego"
dataPath := "test/data.json"

fs := map[string]string{
policyPath: `package test

x := 1
`,
"test/tests.rego": `package test

import data.y

test_x {
x == 1
}

test_y {
y == 2
}
`,
dataPath: `{"y": 2}`,
}

updatedPolicy1 := `package test

x := 3
`

newPolicyPath := "test/tests2.rego"
newPolicy := `package added_test

foo := "bar"

test_foo {
foo == "baz"
}
`

updatedData := `{"y": 3}`

expectedOutput := `PASS: 2/2
********************************************************************************
Watching for changes ...
%ROOT%/test/tests.rego:
data.test.test_x: FAIL (%TIME%)
--------------------------------------------------------------------------------
PASS: 1/2
FAIL: 1/2
********************************************************************************
Watching for changes ...
%ROOT%/test/tests.rego:
data.test.test_x: FAIL (%TIME%)
data.test.test_y: FAIL (%TIME%)
--------------------------------------------------------------------------------
FAIL: 2/2
********************************************************************************
Watching for changes ...
%ROOT%/test/tests.rego:
data.test.test_x: FAIL (%TIME%)
data.test.test_y: FAIL (%TIME%)

%ROOT%/test/tests2.rego:
data.added_test.test_foo: FAIL (%TIME%)
--------------------------------------------------------------------------------
FAIL: 3/3
********************************************************************************
Watching for changes ...
%ROOT%/test/tests.rego:
data.test.test_x: FAIL (%TIME%)
data.test.test_y: FAIL (%TIME%)
--------------------------------------------------------------------------------
FAIL: 2/2
********************************************************************************
Watching for changes ...
`

test.WithTempFS(fs, func(p string) {
testParams.watch = true

rootDir := filepath.Join(p, "test")

go opaTest([]string{rootDir})

time.Sleep(1 * time.Second)

// Update policy file
if err := os.WriteFile(path.Join(p, policyPath), []byte(updatedPolicy1), 0644); err != nil {
t.Fatal(err)
}

time.Sleep(1 * time.Second)

// Update data file
if err := os.WriteFile(path.Join(p, dataPath), []byte(updatedData), 0644); err != nil {
t.Fatal(err)
}

time.Sleep(1 * time.Second)

// add new policy file
if err := os.WriteFile(path.Join(p, newPolicyPath), []byte(newPolicy), 0644); err != nil {
t.Fatal(err)
}

time.Sleep(1 * time.Second)

// remove added policy file
if err := os.Remove(path.Join(p, newPolicyPath)); err != nil {
t.Fatal(err)
}

time.Sleep(1 * time.Second)

// TODO: Test adding and removing data files (currently not supported)

testParams.killChan <- syscall.SIGINT

r := regexp.MustCompile(`FAIL \(.*s\)`)
actualOutput := r.ReplaceAllString(buff.String(), "FAIL (%TIME%)")

expectedOutput = strings.ReplaceAll(expectedOutput, "%ROOT%", p)

if expectedOutput != actualOutput {
t.Fatalf("Expected:\n\n%s\n\nGot:\n\n%s\n\n", expectedOutput, actualOutput)
}
})
}
Loading