Skip to content

Commit

Permalink
ast: Don't rewrite ref to function in with stmt
Browse files Browse the repository at this point in the history
Fixing issue where policies containing `with` replacing unknown functions can't be inspected.

Fixes: open-policy-agent#6812
Signed-off-by: Johan Fylling <johan.dev@fylling.se>
  • Loading branch information
johanfylling committed Jun 13, 2024
1 parent b463d30 commit 69e30e9
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 51 deletions.
22 changes: 17 additions & 5 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -5497,23 +5497,23 @@ func validateWith(c *Compiler, unsafeBuiltinsMap map[string]struct{}, expr *Expr
switch {
case isDataRef(target):
ref := target.Value.(Ref)
node := c.RuleTree
targetNode := c.RuleTree
for i := 0; i < len(ref)-1; i++ {
child := node.Child(ref[i].Value)
child := targetNode.Child(ref[i].Value)
if child == nil {
break
} else if len(child.Values) > 0 {
return false, NewError(CompileErr, target.Loc(), "with keyword cannot partially replace virtual document(s)")
}
node = child
targetNode = child
}

if node != nil {
if targetNode != nil {
// NOTE(sr): at this point in the compiler stages, we don't have a fully-populated
// TypeEnv yet -- so we have to make do with this check to see if the replacement
// target is a function. It's probably wrong for arity-0 functions, but those are
// and edge case anyways.
if child := node.Child(ref[len(ref)-1].Value); child != nil {
if child := targetNode.Child(ref[len(ref)-1].Value); child != nil {
for _, v := range child.Values {
if len(v.(*Rule).Head.Args) > 0 {
if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok {
Expand All @@ -5523,6 +5523,18 @@ func validateWith(c *Compiler, unsafeBuiltinsMap map[string]struct{}, expr *Expr
}
}
}

// If the with-value is a ref to a function, but not a call, we can't rewrite it
if r, ok := value.Value.(Ref); ok {
// TODO: check that target ref doesn't exist?
if valueNode := c.RuleTree.Find(r); valueNode != nil {
for _, v := range valueNode.Values {
if len(v.(*Rule).Head.Args) > 0 {
return false, nil
}
}
}
}
case isInputRef(target): // ok, valid
case isBuiltinRefOrVar:

Expand Down
180 changes: 134 additions & 46 deletions cmd/inspect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1028,47 +1028,25 @@ p contains 2 if {
}
}

func TestCallToUnknownBuiltInFunction(t *testing.T) {
files := [][2]string{
{"/policy.rego", `package test
p {
foo.bar(42)
contains("foo", "o")
}
`},
}

buf := archive.MustWriteTarGz(files)

test.WithTempFS(nil, func(rootDir string) {
bundleFile := filepath.Join(rootDir, "bundle.tar.gz")

bf, err := os.Create(bundleFile)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

_, err = bf.Write(buf.Bytes())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

var out bytes.Buffer
params := newInspectCommandParams()
err = params.outputFormat.Set(evalJSONOutput)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}

err = doInspect(params, bundleFile, &out)
if err != nil {
t.Fatalf("Unexpected error %v", err)
}

bs := out.Bytes()
output := strings.TrimSpace(string(bs))
// Note: unknown foo.bar() built-in doesn't appear in the output, but also didn't cause an error.
expected := strings.TrimSpace(`{
func TestUnknownRefs(t *testing.T) {
tests := []struct {
note string
files [][2]string
expected string
}{
{
note: "unknown built-in func call",
files: [][2]string{
{
"/policy.rego", `package test
p {
foo.bar(42)
contains("foo", "o")
}`,
},
},
// Note: unknown foo.bar() built-in doesn't appear in the output, but also didn't cause an error.
expected: `{
"manifest": {
"revision": "",
"roots": [
Expand Down Expand Up @@ -1102,12 +1080,122 @@ func TestCallToUnknownBuiltInFunction(t *testing.T) {
}
]
}
}`)
}`,
},
{
// Happy path
note: "ref replaced inside with stmt",
files: [][2]string{
{"/policy.rego", `package test
import rego.v1
if output != expected {
t.Fatalf("Unexpected output. Expected:\n\n%s\n\nGot:\n\n%s", expected, output)
}
})
foo.bar(_) := false
p if {
foo.bar(42)
}
mock(_) := true
test_p if {
p with data.test.foo.bar as mock
}`},
},
expected: `{
"manifest": {
"revision": "",
"roots": [
""
]
},
"signatures_config": {},
"namespaces": {
"data.test": [
"/policy.rego"
]
},
"capabilities": {
"features": [
"rego_v1_import"
]
}
}`,
},
{
note: "unknown ref replaced inside with stmt",
files: [][2]string{
{"/policy.rego", `package test
import rego.v1
p if {
data.foo.bar(42)
}
mock(_) := true
test_p if {
p with data.foo.bar as mock
}`},
},
expected: `{
"manifest": {
"revision": "",
"roots": [
""
]
},
"signatures_config": {},
"namespaces": {
"data.test": [
"/policy.rego"
]
},
"capabilities": {
"features": [
"rego_v1_import"
]
}
}`,
},
}

for _, tc := range tests {
t.Run(tc.note, func(t *testing.T) {
buf := archive.MustWriteTarGz(tc.files)

test.WithTempFS(nil, func(rootDir string) {
bundleFile := filepath.Join(rootDir, "bundle.tar.gz")

bf, err := os.Create(bundleFile)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

_, err = bf.Write(buf.Bytes())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

var out bytes.Buffer
params := newInspectCommandParams()
err = params.outputFormat.Set(evalJSONOutput)
if err != nil {
t.Fatalf("Unexpected error: %s", err)
}

err = doInspect(params, bundleFile, &out)
if err != nil {
t.Fatalf("Unexpected error %v", err)
}

bs := out.Bytes()
output := strings.TrimSpace(string(bs))
if output != tc.expected {
t.Fatalf("Unexpected output. Expected:\n\n%s\n\nGot:\n\n%s", tc.expected, output)
}
})
})
}
}

func TestCallToUnknownRegoFunction(t *testing.T) {
Expand Down

0 comments on commit 69e30e9

Please sign in to comment.