diff --git a/cmd/test.go b/cmd/test.go index fa31bab603..bfb2d4fc1a 100644 --- a/cmd/test.go +++ b/cmd/test.go @@ -7,7 +7,11 @@ package cmd import ( "context" "fmt" + "io" "os" + "os/signal" + "strings" + "syscall" "time" "github.com/spf13/cobra" @@ -16,7 +20,11 @@ import ( "github.com/open-policy-agent/opa/bundle" "github.com/open-policy-agent/opa/compile" "github.com/open-policy-agent/opa/cover" + "github.com/open-policy-agent/opa/filewatcher" "github.com/open-policy-agent/opa/internal/runtime" + initload "github.com/open-policy-agent/opa/internal/runtime/init" + "github.com/open-policy-agent/opa/loader" + "github.com/open-policy-agent/opa/logging" "github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/storage/inmem" "github.com/open-policy-agent/opa/tester" @@ -47,6 +55,9 @@ type testCommandParams struct { target *util.EnumFlag skipExitZero bool capabilities *capabilitiesFlag + watch bool + output io.Writer + killChan chan os.Signal } func newTestCommandParams() *testCommandParams { @@ -55,6 +66,8 @@ func newTestCommandParams() *testCommandParams { explain: newExplainFlag([]string{explainModeFails, explainModeFull, explainModeNotes, explainModeDebug}), target: util.NewEnumFlag(compile.TargetRego, []string{compile.TargetRego, compile.TargetWasm}), capabilities: newcapabilitiesFlag(), + output: os.Stdout, + killChan: make(chan os.Signal, 1), } } @@ -148,10 +161,66 @@ The optional "gobench" output format conforms to the Go Benchmark Data Format. }, } -func opaTest(args []string) int { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() +func newOnReload(c chan int) filewatcher.OnReload { + + onReload := func(ctx context.Context, txn storage.Transaction, d time.Duration, s storage.Store, l *initload.LoadPathsResult, err error) { + notify := func() { + c <- 1 + } + defer notify() + + if err != nil { + fmt.Printf("error reloading files: %v\n", err) + return + } + + // FIXME: We don't detect when data files are removed. + if len(l.Files.Documents) > 0 { + if err := s.Write(ctx, txn, storage.AddOp, storage.Path{}, l.Files.Documents); err != nil { + fmt.Printf("storage error: %v\n", err) + return + } + } + + modules := map[string]*ast.Module{} + for id, module := range l.Files.Modules { + modules[id] = module.Parsed + } + + compileAndRunTests(ctx, txn, s, modules, l.Bundles) + } + + return onReload +} + +func watchTests(ctx context.Context, paths []string, filter loader.Filter, bundleMode bool, store storage.Store) int { + reloadChan := make(chan int) + onReload := newOnReload(reloadChan) + + signal.Notify(testParams.killChan, syscall.SIGINT, syscall.SIGTERM) + + logger := logging.New() + + w := filewatcher.NewFileWatcher(paths, filter, bundleMode, store, onReload, logger) + err := w.Start(ctx) + if err != nil { + fmt.Fprintln(os.Stderr, "error", err) + return 1 + } + + for { + fmt.Fprintln(testParams.output, strings.Repeat("*", 80)) + fmt.Fprintln(testParams.output, "Watching for changes ...") + select { + case <-testParams.killChan: + return 0 + case <-reloadChan: + break + } + } +} +func opaTest(args []string) int { if testParams.outputFormat.String() == benchmarkGoBenchOutput && !testParams.benchmark { fmt.Fprintf(os.Stderr, "cannot use output format %s without running benchmarks (--bench)\n", benchmarkGoBenchOutput) return 0 @@ -166,30 +235,44 @@ func opaTest(args []string) int { Ignore: testParams.ignore, } - var modules map[string]*ast.Module - var bundles map[string]*bundle.Bundle var store storage.Store var err error - if testParams.bundleMode { - bundles, err = tester.LoadBundles(args, filter.Apply) - store = inmem.NewWithOpts(inmem.OptRoundTripOnWrite(false)) - } else { - modules, store, err = tester.Load(args, filter.Apply) - } - + result, err := initload.LoadPaths(args, filter.Apply, testParams.bundleMode, nil, true, false, nil) if err != nil { fmt.Fprintln(os.Stderr, err) return 1 } + store = inmem.NewFromObjectWithOpts(result.Files.Documents, inmem.OptRoundTripOnWrite(false)) + + modules := map[string]*ast.Module{} + for _, m := range result.Files.Modules { + modules[m.Name] = m.Parsed + } + + bundles := result.Bundles + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + txn, err := store.NewTransaction(ctx, storage.WriteParams) if err != nil { fmt.Fprintln(os.Stderr, err) return 1 } - defer store.Abort(ctx, txn) + if testParams.watch { + compileAndRunTests(ctx, txn, store, modules, bundles) + store.Commit(ctx, txn) + return watchTests(ctx, args, filter.Apply, testParams.bundleMode, store) + } else { + defer store.Abort(ctx, txn) + return compileAndRunTests(ctx, txn, store, modules, bundles) + } +} + +func compileAndRunTests(ctx context.Context, txn storage.Transaction, store storage.Store, modules map[string]*ast.Module, bundles map[string]*bundle.Bundle) int { var capabilities *ast.Capabilities // if capabilities are not provided as a cmd flag, @@ -266,7 +349,7 @@ func opaTest(args []string) int { default: reporter = tester.PrettyReporter{ Verbose: testParams.verbose, - Output: os.Stdout, + Output: testParams.output, BenchmarkResults: testParams.benchmark, BenchMarkShowAllocations: testParams.benchMem, BenchMarkGoBenchFormat: goBench, @@ -276,7 +359,7 @@ func opaTest(args []string) int { reporter = tester.JSONCoverageReporter{ Cover: cov, Modules: modules, - Output: os.Stdout, + Output: testParams.output, Threshold: testParams.threshold, } } @@ -386,5 +469,6 @@ func init() { setExplainFlag(testCommand.Flags(), testParams.explain) addTargetFlag(testCommand.Flags(), testParams.target) addCapabilitiesFlag(testCommand.Flags(), testParams.capabilities) + testCommand.Flags().BoolVarP(&testParams.watch, "watch", "w", false, "watch for file changes and re-run tests") RootCommand.AddCommand(testCommand) } diff --git a/cmd/test_test.go b/cmd/test_test.go index f0ab9554a5..d80654db7f 100644 --- a/cmd/test_test.go +++ b/cmd/test_test.go @@ -3,8 +3,14 @@ package cmd import ( "bytes" "context" + "os" + "path" + "path/filepath" + "regexp" "strings" + "syscall" "testing" + "time" "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/rego" @@ -272,3 +278,136 @@ test_p { t.Fatalf("didn't get expected %s error when inlined schema is present; got: %v", ast.TypeErr, err) } } + +func TestWatchMode(t *testing.T) { + + var buff bytes.Buffer + testParams.output = &buff + + policyPath := "test/policy.rego" + dataPath := "test/data.json" + + fs := map[string]string{ + policyPath: `package test + + x := 1 + `, + "test/tests.rego": `package test + + import data.y + + test_x { + x == 1 + } + + test_y { + y == 2 + } + `, + dataPath: `{"y": 2}`, + } + + updatedPolicy1 := `package test + + x := 3 + ` + + newPolicyPath := "test/tests2.rego" + newPolicy := `package added_test + + foo := "bar" + + test_foo { + foo == "baz" + } + ` + + updatedData := `{"y": 3}` + + expectedOutput := `PASS: 2/2 +******************************************************************************** +Watching for changes ... +%ROOT%/test/tests.rego: +data.test.test_x: FAIL (%TIME%) +-------------------------------------------------------------------------------- +PASS: 1/2 +FAIL: 1/2 +******************************************************************************** +Watching for changes ... +%ROOT%/test/tests.rego: +data.test.test_x: FAIL (%TIME%) +data.test.test_y: FAIL (%TIME%) +-------------------------------------------------------------------------------- +FAIL: 2/2 +******************************************************************************** +Watching for changes ... +%ROOT%/test/tests.rego: +data.test.test_x: FAIL (%TIME%) +data.test.test_y: FAIL (%TIME%) + +%ROOT%/test/tests2.rego: +data.added_test.test_foo: FAIL (%TIME%) +-------------------------------------------------------------------------------- +FAIL: 3/3 +******************************************************************************** +Watching for changes ... +%ROOT%/test/tests.rego: +data.test.test_x: FAIL (%TIME%) +data.test.test_y: FAIL (%TIME%) +-------------------------------------------------------------------------------- +FAIL: 2/2 +******************************************************************************** +Watching for changes ... +` + + test.WithTempFS(fs, func(p string) { + testParams.watch = true + + rootDir := filepath.Join(p, "test") + + go opaTest([]string{rootDir}) + + time.Sleep(1 * time.Second) + + // Update policy file + if err := os.WriteFile(path.Join(p, policyPath), []byte(updatedPolicy1), 0644); err != nil { + t.Fatal(err) + } + + time.Sleep(1 * time.Second) + + // Update data file + if err := os.WriteFile(path.Join(p, dataPath), []byte(updatedData), 0644); err != nil { + t.Fatal(err) + } + + time.Sleep(1 * time.Second) + + // add new policy file + if err := os.WriteFile(path.Join(p, newPolicyPath), []byte(newPolicy), 0644); err != nil { + t.Fatal(err) + } + + time.Sleep(1 * time.Second) + + // remove added policy file + if err := os.Remove(path.Join(p, newPolicyPath)); err != nil { + t.Fatal(err) + } + + time.Sleep(1 * time.Second) + + // TODO: Test adding and removing data files (currently not supported) + + testParams.killChan <- syscall.SIGINT + + r := regexp.MustCompile(`FAIL \(.*s\)`) + actualOutput := r.ReplaceAllString(buff.String(), "FAIL (%TIME%)") + + expectedOutput = strings.ReplaceAll(expectedOutput, "%ROOT%", p) + + if expectedOutput != actualOutput { + t.Fatalf("Expected:\n\n%s\n\nGot:\n\n%s\n\n", expectedOutput, actualOutput) + } + }) +} diff --git a/filewatcher/filewatcher.go b/filewatcher/filewatcher.go new file mode 100644 index 0000000000..655a917b9f --- /dev/null +++ b/filewatcher/filewatcher.go @@ -0,0 +1,159 @@ +// Copyright 2023 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package filewatcher + +import ( + "context" + "time" + + "github.com/fsnotify/fsnotify" + + "github.com/open-policy-agent/opa/ast" + initload "github.com/open-policy-agent/opa/internal/runtime/init" + "github.com/open-policy-agent/opa/loader" + "github.com/open-policy-agent/opa/logging" + "github.com/open-policy-agent/opa/storage" +) + +type OnReload func(context.Context, storage.Transaction, time.Duration, storage.Store, *initload.LoadPathsResult, error) + +type FileWatcher struct { + paths []string + filter loader.Filter + bundleMode bool + store storage.Store + onReload OnReload + logger logging.Logger +} + +func NewFileWatcher(paths []string, filter loader.Filter, bundleMode bool, store storage.Store, onReload OnReload, logger logging.Logger) *FileWatcher { + return &FileWatcher{ + paths: paths, + filter: filter, + bundleMode: bundleMode, + store: store, + onReload: onReload, + logger: logger, + } +} + +func (w *FileWatcher) Start(ctx context.Context) error { + watcher, err := w.getWatcher(w.paths) + if err != nil { + return err + } + go w.readWatcher(ctx, watcher) + return nil +} + +func (w *FileWatcher) getWatcher(rootPaths []string) (*fsnotify.Watcher, error) { + watchPaths, err := getWatchPaths(rootPaths) + if err != nil { + return nil, err + } + + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, err + } + + for _, path := range watchPaths { + w.logger.WithFields(map[string]interface{}{"path": path}).Debug("watching path") + if err := watcher.Add(path); err != nil { + return nil, err + } + } + + return watcher, nil +} + +func (w *FileWatcher) readWatcher(ctx context.Context, watcher *fsnotify.Watcher) { + for evt := range watcher.Events { + removalMask := fsnotify.Remove | fsnotify.Rename + mask := fsnotify.Create | fsnotify.Write | removalMask + if (evt.Op & mask) != 0 { + w.logger.WithFields(map[string]interface{}{ + "event": evt.String(), + }).Debug("Registered file event.") + removed := "" + if (evt.Op & removalMask) != 0 { + removed = evt.Name + } + w.processWatcherUpdate(ctx, w.paths, removed) + } + } +} + +func (w *FileWatcher) processWatcherUpdate(ctx context.Context, paths []string, removed string) { + t0 := time.Now() + + loaded, err := initload.LoadPaths(paths, w.filter, w.bundleMode, nil, true, false, nil) + if err != nil { + w.onReload(ctx, nil, time.Since(t0), w.store, nil, err) + return + } + + removed = loader.CleanPath(removed) + + err = storage.Txn(ctx, w.store, storage.WriteParams, func(txn storage.Transaction) error { + if !w.bundleMode { + ids, err := w.store.ListPolicies(ctx, txn) + if err != nil { + return err + } + for _, id := range ids { + if id == removed { + if err := w.store.DeletePolicy(ctx, txn, id); err != nil { + return err + } + } else if _, exists := loaded.Files.Modules[id]; !exists { + // This branch get hit in two cases. + // 1. Another piece of code has access to the store and inserts + // a policy out-of-band. + // 2. In between FS notification and loader.Filtered() call above, a + // policy is removed from disk. + bs, err := w.store.GetPolicy(ctx, txn, id) + if err != nil { + return err + } + module, err := ast.ParseModule(id, string(bs)) + if err != nil { + return err + } + loaded.Files.Modules[id] = &loader.RegoFile{ + Name: id, + Raw: bs, + Parsed: module, + } + } + } + } + + // It's up to onReload to update the store with loaded content + w.onReload(ctx, txn, time.Since(t0), w.store, loaded, err) + return nil + }) + + if err != nil { + w.onReload(ctx, nil, time.Since(t0), w.store, nil, err) + } +} + +func getWatchPaths(rootPaths []string) ([]string, error) { + paths := []string{} + + for _, path := range rootPaths { + + _, path = loader.SplitPrefix(path) + result, err := loader.Paths(path, true) + if err != nil { + return nil, err + } + + paths = append(paths, loader.Dirs(result)...) + } + + return paths, nil +} diff --git a/filewatcher/filewatcher_test.go b/filewatcher/filewatcher_test.go new file mode 100644 index 0000000000..528aa47b7e --- /dev/null +++ b/filewatcher/filewatcher_test.go @@ -0,0 +1,39 @@ +// Copyright 2023 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package filewatcher + +import ( + "path/filepath" + "reflect" + "strings" + "testing" + + "github.com/open-policy-agent/opa/util/test" +) + +func TestWatchPaths(t *testing.T) { + + fs := map[string]string{ + "/foo/bar/baz.json": "true", + } + + expected := []string{ + ".", "/foo", "/foo/bar", + } + + test.WithTempFS(fs, func(rootDir string) { + paths, err := getWatchPaths([]string{"prefix:" + rootDir + "/foo"}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + result := []string{} + for _, p := range paths { + result = append(result, filepath.Clean(strings.TrimPrefix(p, rootDir))) + } + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Expected %q but got: %q", expected, result) + } + }) +} diff --git a/runtime/runtime.go b/runtime/runtime.go index 3942e7d2e2..2d1ffa3da4 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -21,15 +21,14 @@ import ( "syscall" "time" - "github.com/fsnotify/fsnotify" "github.com/gorilla/mux" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel/exporters/otlp/otlptrace" "go.opentelemetry.io/otel/propagation" "go.uber.org/automaxprocs/maxprocs" - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/bundle" + "github.com/open-policy-agent/opa/filewatcher" "github.com/open-policy-agent/opa/internal/config" internal_tracing "github.com/open-policy-agent/opa/internal/distributedtracing" internal_logging "github.com/open-policy-agent/opa/internal/logging" @@ -696,89 +695,9 @@ func (rt *Runtime) decisionLogger(ctx context.Context, event *server.Info) error return plugin.Log(ctx, event) } -func (rt *Runtime) startWatcher(ctx context.Context, paths []string, onReload func(time.Duration, error)) error { - watcher, err := rt.getWatcher(paths) - if err != nil { - return err - } - go rt.readWatcher(ctx, watcher, paths, onReload) - return nil -} - -func (rt *Runtime) readWatcher(ctx context.Context, watcher *fsnotify.Watcher, paths []string, onReload func(time.Duration, error)) { - for evt := range watcher.Events { - removalMask := fsnotify.Remove | fsnotify.Rename - mask := fsnotify.Create | fsnotify.Write | removalMask - if (evt.Op & mask) != 0 { - rt.logger.WithFields(map[string]interface{}{ - "event": evt.String(), - }).Debug("Registered file event.") - t0 := time.Now() - removed := "" - if (evt.Op & removalMask) != 0 { - removed = evt.Name - } - err := rt.processWatcherUpdate(ctx, paths, removed) - onReload(time.Since(t0), err) - } - } -} - -func (rt *Runtime) processWatcherUpdate(ctx context.Context, paths []string, removed string) error { - loaded, err := initload.LoadPaths(paths, rt.Params.Filter, rt.Params.BundleMode, nil, true, false, nil) - if err != nil { - return err - } - - removed = loader.CleanPath(removed) - - return storage.Txn(ctx, rt.Store, storage.WriteParams, func(txn storage.Transaction) error { - if !rt.Params.BundleMode { - ids, err := rt.Store.ListPolicies(ctx, txn) - if err != nil { - return err - } - for _, id := range ids { - if id == removed { - if err := rt.Store.DeletePolicy(ctx, txn, id); err != nil { - return err - } - } else if _, exists := loaded.Files.Modules[id]; !exists { - // This branch get hit in two cases. - // 1. Another piece of code has access to the store and inserts - // a policy out-of-band. - // 2. In between FS notification and loader.Filtered() call above, a - // policy is removed from disk. - bs, err := rt.Store.GetPolicy(ctx, txn, id) - if err != nil { - return err - } - module, err := ast.ParseModule(id, string(bs)) - if err != nil { - return err - } - loaded.Files.Modules[id] = &loader.RegoFile{ - Name: id, - Raw: bs, - Parsed: module, - } - } - } - } - - _, err := initload.InsertAndCompile(ctx, initload.InsertAndCompileOptions{ - Store: rt.Store, - Txn: txn, - Files: loaded.Files, - Bundles: loaded.Bundles, - MaxErrors: -1, - }) - if err != nil { - return err - } - - return nil - }) +func (rt *Runtime) startWatcher(ctx context.Context, paths []string, onReload filewatcher.OnReload) error { + watcher := filewatcher.NewFileWatcher(paths, rt.Params.Filter, rt.Params.BundleMode, rt.Store, onReload, rt.logger) + return watcher.Start(ctx) } func (rt *Runtime) getBanner() string { @@ -827,34 +746,23 @@ func (rt *Runtime) waitPluginsReady(checkInterval, timeout time.Duration) error return util.WaitFunc(pluginsReady, checkInterval, timeout) } -func (rt *Runtime) onReloadLogger(d time.Duration, err error) { +func (rt *Runtime) onReloadLogger(ctx context.Context, txn storage.Transaction, d time.Duration, s storage.Store, l *initload.LoadPathsResult, err error) { + if err == nil { + _, err = initload.InsertAndCompile(ctx, initload.InsertAndCompileOptions{ + Store: s, + Txn: txn, + Files: l.Files, + Bundles: l.Bundles, + MaxErrors: -1, + }) + } + rt.logger.WithFields(map[string]interface{}{ "duration": d, "err": err, }).Info("Processed file watch event.") } -func (rt *Runtime) getWatcher(rootPaths []string) (*fsnotify.Watcher, error) { - watchPaths, err := getWatchPaths(rootPaths) - if err != nil { - return nil, err - } - - watcher, err := fsnotify.NewWatcher() - if err != nil { - return nil, err - } - - for _, path := range watchPaths { - rt.logger.WithFields(map[string]interface{}{"path": path}).Debug("watching path") - if err := watcher.Add(path); err != nil { - return nil, err - } - } - - return watcher, nil -} - func urlPathToConfigOverride(pathCount int, path string) ([]string, error) { uri, err := url.Parse(path) if err != nil { @@ -880,25 +788,18 @@ func errorLogger(logger logging.Logger) func(attrs map[string]interface{}, f str } } -func getWatchPaths(rootPaths []string) ([]string, error) { - paths := []string{} - - for _, path := range rootPaths { - - _, path = loader.SplitPrefix(path) - result, err := loader.Paths(path, true) - if err != nil { - return nil, err +func onReloadPrinter(output io.Writer) filewatcher.OnReload { + return func(ctx context.Context, txn storage.Transaction, d time.Duration, s storage.Store, l *initload.LoadPathsResult, err error) { + if err == nil { + _, err = initload.InsertAndCompile(ctx, initload.InsertAndCompileOptions{ + Store: s, + Txn: txn, + Files: l.Files, + Bundles: l.Bundles, + MaxErrors: -1, + }) } - paths = append(paths, loader.Dirs(result)...) - } - - return paths, nil -} - -func onReloadPrinter(output io.Writer) func(time.Duration, error) { - return func(d time.Duration, err error) { if err != nil { fmt.Fprintf(output, "\n# reload error (took %v): %v", d, err) } else { diff --git a/runtime/runtime_test.go b/runtime/runtime_test.go index f11e060987..3b647c4aa9 100644 --- a/runtime/runtime_test.go +++ b/runtime/runtime_test.go @@ -18,41 +18,17 @@ import ( "testing" "time" + "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/internal/report" + initload "github.com/open-policy-agent/opa/internal/runtime/init" "github.com/open-policy-agent/opa/logging" "github.com/open-policy-agent/opa/server" - - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/storage" "github.com/open-policy-agent/opa/util" "github.com/open-policy-agent/opa/util/test" ) -func TestWatchPaths(t *testing.T) { - - fs := map[string]string{ - "/foo/bar/baz.json": "true", - } - - expected := []string{ - ".", "/foo", "/foo/bar", - } - - test.WithTempFS(fs, func(rootDir string) { - paths, err := getWatchPaths([]string{"prefix:" + rootDir + "/foo"}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - result := []string{} - for _, p := range paths { - result = append(result, filepath.Clean(strings.TrimPrefix(p, rootDir))) - } - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %q but got: %q", expected, result) - } - }) -} - +// TODO: Refactor watch tests to be reusable. func TestRuntimeProcessWatchEvents(t *testing.T) { testRuntimeProcessWatchEvents(t, false) } @@ -181,7 +157,7 @@ func testRuntimeProcessWatchEventPolicyError(t *testing.T, asBundle bool) { ch := make(chan error) - testFunc := func(d time.Duration, err error) { + testFunc := func(ctx context.Context, txn storage.Transaction, d time.Duration, s storage.Store, l *initload.LoadPathsResult, err error) { ch <- err }