Skip to content

Commit

Permalink
debug: Configurable rego-options on debugger (#7053)
Browse files Browse the repository at this point in the history
* debug: Configurable rego-options on debugger

Adding `RegoOption` launch option to debugger for setting custom rego options.

Fixes: #7045
Signed-off-by: Johan Fylling <johan.dev@fylling.se>
  • Loading branch information
johanfylling authored Sep 23, 2024
1 parent e959bce commit 09c1bdf
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 23 deletions.
60 changes: 45 additions & 15 deletions debug/debugger.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import (
type Debugger interface {
// LaunchEval starts a new eval debug session with the given LaunchEvalProperties.
// The returned session is in a stopped state, and must be resumed to start execution.
LaunchEval(ctx context.Context, props LaunchEvalProperties) (Session, error)
LaunchEval(ctx context.Context, props LaunchEvalProperties, opts ...LaunchOption) (Session, error)
}

type debugger struct {
Expand Down Expand Up @@ -207,14 +207,38 @@ type LaunchTestProperties struct {
}

type LaunchProperties struct {
BundlePaths []string
DataPaths []string
StopOnResult bool
StopOnEntry bool
StopOnFail bool
EnablePrint bool
SkipOps []topdown.Op
RuleIndexing bool
BundlePaths []string
DataPaths []string
StopOnResult bool
StopOnEntry bool
StopOnFail bool
EnablePrint bool
SkipOps []topdown.Op
StrictBuiltinErrors bool
RuleIndexing bool
}

type LaunchOption func(options *launchOptions)

type launchOptions struct {
regoOptions []func(*rego.Rego)
}

func newLaunchOptions(opts []LaunchOption) *launchOptions {
options := &launchOptions{}
for _, opt := range opts {
opt(options)
}
return options
}

// RegoOption adds a rego option to the internal Rego instance.
// Options may be overridden by the debugger, and it is recommended to
// use LaunchEvalProperties for commonly used options.
func RegoOption(opt func(*rego.Rego)) LaunchOption {
return func(options *launchOptions) {
options.regoOptions = append(options.regoOptions, opt)
}
}

func (lp LaunchProperties) String() string {
Expand All @@ -225,18 +249,24 @@ func (lp LaunchProperties) String() string {
return string(b)
}

func (d *debugger) LaunchEval(ctx context.Context, props LaunchEvalProperties) (Session, error) {
func (d *debugger) LaunchEval(ctx context.Context, props LaunchEvalProperties, opts ...LaunchOption) (Session, error) {
options := newLaunchOptions(opts)

store := inmem.New()
txn, err := store.NewTransaction(ctx, storage.TransactionParams{Write: true})
if err != nil {
return nil, fmt.Errorf("failed to create store transaction: %v", err)
}

regoArgs := []func(*rego.Rego){
rego.Query(props.Query),
rego.Store(store),
rego.Transaction(txn),
}
regoArgs := make([]func(*rego.Rego), 0, 4)

// We apply all user options first, so the debugger can make overrides if necessary.
regoArgs = append(regoArgs, options.regoOptions...)

regoArgs = append(regoArgs, rego.Query(props.Query))
regoArgs = append(regoArgs, rego.Store(store))
regoArgs = append(regoArgs, rego.Transaction(txn))
regoArgs = append(regoArgs, rego.StrictBuiltinErrors(props.StrictBuiltinErrors))

if props.SkipOps == nil {
props.SkipOps = []topdown.Op{topdown.IndexOp, topdown.RedoOp, topdown.SaveOp, topdown.UnifyOp}
Expand Down
77 changes: 69 additions & 8 deletions debug/debugger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package debug

import (
"context"
"encoding/json"
"fmt"
"path"
"reflect"
Expand All @@ -20,6 +21,8 @@ import (
"github.com/open-policy-agent/opa/storage"
"github.com/open-policy-agent/opa/storage/inmem"
"github.com/open-policy-agent/opa/topdown"
"github.com/open-policy-agent/opa/topdown/builtins"
"github.com/open-policy-agent/opa/types"
"github.com/open-policy-agent/opa/util/test"
)

Expand Down Expand Up @@ -2018,14 +2021,6 @@ func newTestStack(events ...*topdown.Event) *testStack {
}
}

//func (ts *testStack) done() bool {
// return ts.index >= len(ts.events)
//}

//func (ts *testStack) onLastEvent() bool {
// return ts.index == len(ts.events)-1
//}

func (ts *testStack) Enabled() bool {
return true
}
Expand Down Expand Up @@ -2068,3 +2063,69 @@ func (ts *testStack) Close() error {
ts.closed = true
return nil
}

func TestDebuggerCustomBuiltIn(t *testing.T) {
ctx := context.Background()

decl := &rego.Function{
Name: "my.builtin",
Description: "My built-in",
Decl: types.NewFunction(
types.Args(types.S, types.S),
types.S,
),
}

fn := func(_ rego.BuiltinContext, a, b *ast.Term) (*ast.Term, error) {
aStr, err := builtins.StringOperand(a.Value, 1)
if err != nil {
return nil, err
}

bStr, err := builtins.StringOperand(b.Value, 2)
if err != nil {
return nil, err
}

return ast.StringTerm(fmt.Sprintf("%s+%s", aStr, bStr)), nil
}

props := LaunchEvalProperties{
Query: `x := my.builtin("hello", "world")`,
}

exp := `[{"expressions":[{"value":true,"text":"x := my.builtin(\"hello\", \"world\")","location":{"row":1,"col":1}}],"bindings":{"x":"\"hello\"+\"world\""}}]`

eh := newTestEventHandler()

d := NewDebugger(SetEventHandler(eh.HandleEvent))

s, err := d.LaunchEval(ctx, props, RegoOption(rego.Function2(decl, fn)))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

if err := s.ResumeAll(); err != nil {
t.Fatalf("Unexpected error: %v", err)
}

// wait for result
if e := eh.WaitFor(ctx, TerminatedEventType); e == nil {
t.Fatal("Expected terminated event")
}

ts, err := s.Threads()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

res := ts[0].(*thread).stack.Result()
bs, err := json.Marshal(res)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
actual := string(bs)
if actual != exp {
t.Fatalf("Expected:\n\n%v\n\nbut got:\n\n%v", exp, actual)
}
}
3 changes: 3 additions & 0 deletions debug/thread.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ type eventHandler func(t *thread, stackIndex int, e *topdown.Event, s threadStat

type ThreadID int

// Thread represents a single thread of execution.
type Thread interface {
// ID returns the unique identifier for the thread.
ID() ThreadID
// Name returns the human-readable name of the thread.
Name() string
}

Expand Down

0 comments on commit 09c1bdf

Please sign in to comment.