Skip to content

Commit

Permalink
ast: Rewrite input document when passed as an operand to a call
Browse files Browse the repository at this point in the history
Earlier if a reference to the root of the input document was passed as operand in a call, it would not be rewritten to a local variable and substituted in the call. Refs to non-root input/data document would be rewritten. This was happening because we were updating the value of the input root ref term to a variable and as a result the compiler stage that rewrites the body of dynamic terms ('rewriteDynamicTerms') would not rewrite the input root ref. These changes modify the 'rewriteLocalVars' stage so that the value of root ref term is updated based on whether its value is an explicitly declared variable. If it isn't, then it will be rewritten in a later stage.

Fixes open-policy-agent#2084

Signed-off-by: Ashutosh Narkar <anarkar4387@gmail.com>
  • Loading branch information
ashutosh-narkar committed Apr 2, 2020
1 parent 94cd44e commit 517440d
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
16 changes: 15 additions & 1 deletion ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -2966,6 +2966,17 @@ func (s localDeclaredVars) Occurrence(x Var) varOccurrence {
return s.vars[len(s.vars)-1].occurrence[x]
}

// GlobalOccurrence returns a flag that indicates whether x has occurred in the
// global scope.
func (s localDeclaredVars) GlobalOccurrence(x Var) (varOccurrence, bool) {
for i := len(s.vars) - 1; i >= 0; i-- {
if occ, ok := s.vars[i].occurrence[x]; ok {
return occ, true
}
}
return newVar, false
}

// rewriteLocalVars rewrites bodies to remove assignment/declaration
// expressions. For example:
//
Expand Down Expand Up @@ -3148,9 +3159,12 @@ func rewriteDeclaredVarsInTerm(g *localVarGenerator, stack *localDeclaredVars, t
}
case Ref:
if RootDocumentRefs.Contains(term) {
if gv, ok := stack.Declared(v[0].Value.(Var)); ok {
x := v[0].Value.(Var)
if occ, ok := stack.GlobalOccurrence(x); ok && occ != seenVar {
gv, _ := stack.Declared(x)
term.Value = gv
}

return true, errs
}
return false, errs
Expand Down
28 changes: 28 additions & 0 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1970,6 +1970,32 @@ func TestRewriteDeclaredVars(t *testing.T) {
}
`,
},
{
note: "rewrite call with root document ref as arg",
module: `
package test
p {
f(input, "bar")
}
f(x,y) {
x[y]
}
`,
exp: `
package test
p = true {
__local2__ = input;
data.test.f(__local2__, "bar")
}
f(__local0__, __local1__) = true {
__local0__[__local1__]
}
`,
},
{
note: "redeclare err",
module: `
Expand Down Expand Up @@ -2198,6 +2224,8 @@ func TestCompilerRewriteDynamicTerms(t *testing.T) {
{`eq_with { [str] = [1] with input as 1 }`, `__local0__ = data.test.str with input as 1; [__local0__] = [1] with input as 1`},
{`term_with { [[str]] with input as 1 }`, `__local0__ = data.test.str with input as 1; [[__local0__]] with input as 1`},
{`call_with { count(str) with input as 1 }`, `__local0__ = data.test.str with input as 1; count(__local0__) with input as 1`},
{`call_func { f(input, "foo") } f(x,y) { x[y] }`, `__local2__ = input; data.test.f(__local2__, "foo")`},
{`call_func2 { f(input.foo, "foo") } f(x,y) { x[y] }`, `__local2__ = input.foo; data.test.f(__local2__, "foo")`},
}

for _, tc := range tests {
Expand Down

0 comments on commit 517440d

Please sign in to comment.