diff --git a/ast/compile_test.go b/ast/compile_test.go index dce2baf216..223f7eac7b 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -3169,6 +3169,36 @@ func TestCompilerRewritePrintCalls(t *testing.T) { f(__local0__) = __local2__ { true; __local2__ = {1 | __local0__[x]; __local3__ = {__local1__ | __local1__ = x}; internal.print([__local3__])} } `, }, + { + note: "print call of var in head key", + module: `package test + f(_) = [1, 2, 3] + p[x] { [_, x, _] := f(true); print(x) }`, + exp: `package test + f(__local0__) = [1, 2, 3] { true } + p[__local2__] { data.test.f(true, __local5__); [__local1__, __local2__, __local3__] = __local5__; __local6__ = {__local4__ | __local4__ = __local2__}; internal.print([__local6__]) } + `, + }, + { + note: "print call of var in head value", + module: `package test + f(_) = [1, 2, 3] + p = x { [_, x, _] := f(true); print(x) }`, + exp: `package test + f(__local0__) = [1, 2, 3] { true } + p = __local2__ { data.test.f(true, __local5__); [__local1__, __local2__, __local3__] = __local5__; __local6__ = {__local4__ | __local4__ = __local2__}; internal.print([__local6__]) } + `, + }, + { + note: "print call of vars in head key and value", + module: `package test + f(_) = [1, 2, 3] + p[x] = y { [_, x, y] := f(true); print(x) }`, + exp: `package test + f(__local0__) = [1, 2, 3] { true } + p[__local2__] = __local3__ { data.test.f(true, __local5__); [__local1__, __local2__, __local3__] = __local5__; __local6__ = {__local4__ | __local4__ = __local2__}; internal.print([__local6__]) } + `, + }, } for _, tc := range cases { diff --git a/ast/unify.go b/ast/unify.go index 80e8ae31d5..60244974a9 100644 --- a/ast/unify.go +++ b/ast/unify.go @@ -9,10 +9,7 @@ func isRefSafe(ref Ref, safe VarSet) bool { case Var: return safe.Contains(head) case Call: - vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams) - vis.Walk(head) - unsafe := vis.Vars().Diff(safe) - return len(unsafe) == 0 + return isCallSafe(head, safe) default: for v := range ref[0].Vars() { if !safe.Contains(v) { @@ -23,6 +20,13 @@ func isRefSafe(ref Ref, safe VarSet) bool { } } +func isCallSafe(call Call, safe VarSet) bool { + vis := NewVarVisitor().WithParams(SafetyCheckVisitorParams) + vis.Walk(call) + unsafe := vis.Vars().Diff(safe) + return len(unsafe) == 0 +} + // Unify returns a set of variables that will be unified when the equality expression defined by // terms a and b is evaluated. The unifier assumes that variables in the VarSet safe are already // unified. @@ -67,6 +71,10 @@ func (u *unifier) unify(a *Term, b *Term) { if isRefSafe(b, u.safe) { u.markSafe(a) } + case Call: + if isCallSafe(b, u.safe) { + u.markSafe(a) + } default: u.markSafe(a) } @@ -81,6 +89,16 @@ func (u *unifier) unify(a *Term, b *Term) { } } + case Call: + if isCallSafe(a, u.safe) { + switch b := b.Value.(type) { + case Var: + u.markSafe(b) + case *Array, Object: + u.markAllSafe(b) + } + } + case *ArrayComprehension: switch b := b.Value.(type) { case Var: @@ -105,8 +123,16 @@ func (u *unifier) unify(a *Term, b *Term) { switch b := b.Value.(type) { case Var: u.unifyAll(b, a) - case Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension: + case *ArrayComprehension, *ObjectComprehension, *SetComprehension: u.markAllSafe(a) + case Ref: + if isRefSafe(b, u.safe) { + u.markAllSafe(a) + } + case Call: + if isCallSafe(b, u.safe) { + u.markAllSafe(a) + } case *Array: if a.Len() == b.Len() { for i := 0; i < a.Len(); i++ { @@ -120,7 +146,13 @@ func (u *unifier) unify(a *Term, b *Term) { case Var: u.unifyAll(b, a) case Ref: - u.markAllSafe(a) + if isRefSafe(b, u.safe) { + u.markAllSafe(a) + } + case Call: + if isCallSafe(b, u.safe) { + u.markAllSafe(a) + } case *object: if a.Len() == b.Len() { _ = a.Iter(func(k, v *Term) error { diff --git a/ast/unify_test.go b/ast/unify_test.go index 0942c2c6ea..9ea34eb6d7 100644 --- a/ast/unify_test.go +++ b/ast/unify_test.go @@ -4,7 +4,10 @@ package ast -import "testing" +import ( + "fmt" + "testing" +) func TestUnify(t *testing.T) { @@ -28,10 +31,16 @@ func TestUnify(t *testing.T) { {"object/var", `{"x": 1, "y": x} = y`, "[x]", "[y]"}, {"object/var (reversed)", `y = {"x": 1, "y": x}`, "[x]", "[y]"}, {"object/var-2", `{"x": 1, "y": x} = y`, "[y]", "[x]"}, - {"object/var-3", `{"x": 1, "y": x} = y`, "[]", "[]"}, {"object/uneven", `{"x": x, "y": 1} = {"x": y}`, "[]", "[]"}, {"object/uneven", `{"x": x, "y": 1} = {"x": y}`, "[x]", "[]"}, - {"call", "x = f(y)[z]", "[y]", "[x]"}, + {"var/call-ref", "x = f(y)[z]", "[y]", "[x]"}, + {"var/call-ref (reversed)", "f(y)[z] = x", "[y]", "[x]"}, + {"var/call", "x = f(z)", "[z]", "[x]"}, + {"var/call (reversed)", "f(z) = x", "[z]", "[x]"}, + {"array/call", "[x, y] = f(z)", "[z]", "[x,y]"}, + {"array/call (reversed)", "f(z) = [x, y]", "[z]", "[x,y]"}, + {"object/call", `{"a": x} = f(z)`, "[z]", "[x]"}, + {"object/call (reversed)", `f(z) = {"a": x}`, "[z]", "[x]"}, // transitive cases {"trans/redundant", "[x, x] = [x, 0]", "[]", "[x]"}, @@ -43,10 +52,44 @@ func TestUnify(t *testing.T) { {"trans/redundant-nested", "[x, z, z] = [1, [y, x], [2, 1]]", "[]", "[x, y, z]"}, {"trans/bidirectional", "[x, z, y] = [[z,y], [1,y], 2]", "[]", "[x, y, z]"}, {"trans/occurs", "[x, z, y] = [[y,z], [y, 1], [2, x]]", "[]", "[]"}, + + // unsafe refs + {note: "array/ref", expr: "[1,2,x] = a[_]"}, + {note: "array/ref (reversed)", expr: "a[_] = [1,2,x]"}, + {note: "object/ref", expr: `{"x": x} = a[_]`}, + {note: "object/ref (reversed)", expr: `a[_] = {"x": x}`}, + {note: "var/call-ref", expr: "x = f(y)[z]"}, + {note: "var/call-ref (reversed)", expr: "f(y)[z] = x"}, + + // unsafe vars + {note: "array/var", expr: "[1,2,x] = y"}, + {note: "array/var (reversed)", expr: "y = [1,2,x]"}, + {note: "object/var", expr: `{"x": 1, "y": x} = y`}, + {note: "object/var (reversed)", expr: `y = {"x": 1, "y": x}`}, + {note: "var/call", expr: "x = f(z)"}, + {note: "var/call (reversed)", expr: "f(z) = x"}, + + // unsafe call args + {note: "var/call-2", expr: "x = f(z)", safe: "[x]"}, + {note: "var/call-2 (reversed)", expr: "f(z) = x", safe: "[x]"}, + {note: "array/call", expr: "[x, y] = f(z)", safe: "[x,y]"}, + {note: "array/call (reversed)", expr: "f(z) = [x, y]", safe: "[x,y]"}, + {note: "object/call", expr: `{"a": x} = f(z)`, safe: "[x]"}, + {note: "object/call (reversed)", expr: `f(z) = {"a": x}`, safe: "[x]"}, + + // partial cases + {note: "trans/ref", expr: "[x, y, [x, y, i]] = [1, a[i], z]", safe: "[a]", expected: "[x, y]"}, + {note: "trans/ref", expr: "[x, y, [x, y, i]] = [1, a[i], z]", expected: "[x]"}, } - for i, tc := range tests { - t.Run(tc.note, func(t *testing.T) { + for _, tc := range tests { + if tc.expected == "" { + tc.expected = "[]" + } + if tc.safe == "" { + tc.safe = "[]" + } + t.Run(fmt.Sprintf("%s/%s/%s", tc.note, tc.safe, tc.expected), func(t *testing.T) { expr := MustParseBody(tc.expr)[0] safe := VarSet{} @@ -74,7 +117,7 @@ func TestUnify(t *testing.T) { missing := expected.Diff(result) extra := result.Diff(expected) if len(missing) != 0 || len(extra) != 0 { - t.Fatalf("%s (%d): Missing vars: %v, extra vars: %v", tc.note, i, missing, extra) + t.Fatalf("missing vars: %v, extra vars: %v", missing, extra) } }) }