Skip to content

Commit

Permalink
ast/compile: respect unsafeBuiltinMap for 'with' replacements
Browse files Browse the repository at this point in the history
The changes are necessary for both the Compiler and the QueryCompiler. Tests
have been added to ensure that the code path through the rego package has also
been fixed.

Fixes CVE-2022-36085.

Signed-off-by: Stephan Renatus <stephan.renatus@gmail.com>
  • Loading branch information
srenatus committed Sep 7, 2022
1 parent b78756f commit 3e8c754
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 31 deletions.
58 changes: 33 additions & 25 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -2196,7 +2196,7 @@ func (c *Compiler) rewriteWithModifiers() {
if !ok {
return x, nil
}
body, err := rewriteWithModifiersInBody(c, f, body)
body, err := rewriteWithModifiersInBody(c, c.unsafeBuiltinsMap, f, body)
if err != nil {
c.err(err)
}
Expand Down Expand Up @@ -2475,19 +2475,20 @@ func (qc *queryCompiler) checkTypes(_ *QueryContext, body Body) (Body, error) {
}

func (qc *queryCompiler) checkUnsafeBuiltins(_ *QueryContext, body Body) (Body, error) {
var unsafe map[string]struct{}
if qc.unsafeBuiltins != nil {
unsafe = qc.unsafeBuiltins
} else {
unsafe = qc.compiler.unsafeBuiltinsMap
}
errs := checkUnsafeBuiltins(unsafe, body)
errs := checkUnsafeBuiltins(qc.unsafeBuiltinsMap(), body)
if len(errs) > 0 {
return nil, errs
}
return body, nil
}

func (qc *queryCompiler) unsafeBuiltinsMap() map[string]struct{} {
if qc.unsafeBuiltins != nil {
return qc.unsafeBuiltins
}
return qc.compiler.unsafeBuiltinsMap
}

func (qc *queryCompiler) checkDeprecatedBuiltins(_ *QueryContext, body Body) (Body, error) {
errs := checkDeprecatedBuiltins(qc.compiler.deprecatedBuiltinsMap, body, qc.compiler.strict)
if len(errs) > 0 {
Expand All @@ -2498,7 +2499,7 @@ func (qc *queryCompiler) checkDeprecatedBuiltins(_ *QueryContext, body Body) (Bo

func (qc *queryCompiler) rewriteWithModifiers(_ *QueryContext, body Body) (Body, error) {
f := newEqualityFactory(newLocalVarGenerator("q", body))
body, err := rewriteWithModifiersInBody(qc.compiler, f, body)
body, err := rewriteWithModifiersInBody(qc.compiler, qc.unsafeBuiltinsMap(), f, body)
if err != nil {
return nil, Errors{err}
}
Expand Down Expand Up @@ -4779,10 +4780,10 @@ func rewriteDeclaredVar(g *localVarGenerator, stack *localDeclaredVars, v Var, o
// rewriteWithModifiersInBody will rewrite the body so that with modifiers do
// not contain terms that require evaluation as values. If this function
// encounters an invalid with modifier target then it will raise an error.
func rewriteWithModifiersInBody(c *Compiler, f *equalityFactory, body Body) (Body, *Error) {
func rewriteWithModifiersInBody(c *Compiler, unsafeBuiltinsMap map[string]struct{}, f *equalityFactory, body Body) (Body, *Error) {
var result Body
for i := range body {
exprs, err := rewriteWithModifier(c, f, body[i])
exprs, err := rewriteWithModifier(c, unsafeBuiltinsMap, f, body[i])
if err != nil {
return nil, err
}
Expand All @@ -4797,11 +4798,11 @@ func rewriteWithModifiersInBody(c *Compiler, f *equalityFactory, body Body) (Bod
return result, nil
}

func rewriteWithModifier(c *Compiler, f *equalityFactory, expr *Expr) ([]*Expr, *Error) {
func rewriteWithModifier(c *Compiler, unsafeBuiltinsMap map[string]struct{}, f *equalityFactory, expr *Expr) ([]*Expr, *Error) {

var result []*Expr
for i := range expr.With {
eval, err := validateWith(c, expr, i)
eval, err := validateWith(c, unsafeBuiltinsMap, expr, i)
if err != nil {
return nil, err
}
Expand All @@ -4816,7 +4817,7 @@ func rewriteWithModifier(c *Compiler, f *equalityFactory, expr *Expr) ([]*Expr,
return append(result, expr), nil
}

func validateWith(c *Compiler, expr *Expr, i int) (bool, *Error) {
func validateWith(c *Compiler, unsafeBuiltinsMap map[string]struct{}, expr *Expr, i int) (bool, *Error) {
target, value := expr.With[i].Target, expr.With[i].Value

// Ensure that values that are built-ins are rewritten to Ref (not Var)
Expand All @@ -4825,6 +4826,10 @@ func validateWith(c *Compiler, expr *Expr, i int) (bool, *Error) {
value.Value = Ref([]*Term{NewTerm(v)})
}
}
isBuiltinRefOrVar, err := isBuiltinRefOrVar(c.builtins, unsafeBuiltinsMap, target)
if err != nil {
return false, err
}

switch {
case isDataRef(target):
Expand All @@ -4848,15 +4853,15 @@ func validateWith(c *Compiler, expr *Expr, i int) (bool, *Error) {
if child := node.Child(ref[len(ref)-1].Value); child != nil {
for _, v := range child.Values {
if len(v.(*Rule).Head.Args) > 0 {
if validateWithFunctionValue(c.builtins, c.RuleTree, value) {
return false, nil
if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok {
return false, err // may be nil
}
}
}
}
}
case isInputRef(target): // ok, valid
case isBuiltinRefOrVar(c.builtins, target):
case isBuiltinRefOrVar:

// NOTE(sr): first we ensure that parsed Var builtins (`count`, `concat`, etc)
// are rewritten to their proper Ref convention
Expand All @@ -4870,8 +4875,8 @@ func validateWith(c *Compiler, expr *Expr, i int) (bool, *Error) {
return false, err
}

if validateWithFunctionValue(c.builtins, c.RuleTree, value) {
return false, nil
if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok {
return false, err // may be nil
}
default:
return false, NewError(TypeErr, target.Location, "with keyword target must reference existing %v, %v, or a function", InputRootDocument, DefaultRootDocument)
Expand Down Expand Up @@ -4900,13 +4905,13 @@ func validateWithBuiltinTarget(bi *Builtin, target Ref, loc *location.Location)
return nil
}

func validateWithFunctionValue(bs map[string]*Builtin, ruleTree *TreeNode, value *Term) bool {
func validateWithFunctionValue(bs map[string]*Builtin, unsafeMap map[string]struct{}, ruleTree *TreeNode, value *Term) (bool, *Error) {
if v, ok := value.Value.(Ref); ok {
if ruleTree.Find(v) != nil { // ref exists in rule tree
return true
return true, nil
}
}
return isBuiltinRefOrVar(bs, value)
return isBuiltinRefOrVar(bs, unsafeMap, value)
}

func isInputRef(term *Term) bool {
Expand All @@ -4927,13 +4932,16 @@ func isDataRef(term *Term) bool {
return false
}

func isBuiltinRefOrVar(bs map[string]*Builtin, term *Term) bool {
func isBuiltinRefOrVar(bs map[string]*Builtin, unsafeBuiltinsMap map[string]struct{}, term *Term) (bool, *Error) {
switch v := term.Value.(type) {
case Ref, Var:
if _, ok := unsafeBuiltinsMap[v.String()]; ok {
return false, NewError(CompileErr, term.Location, "with keyword replacing built-in function: target must not be unsafe: %q", v)
}
_, ok := bs[v.String()]
return ok
return ok, nil
}
return false
return false, nil
}

func isVirtual(node *TreeNode, ref Ref) bool {
Expand Down
86 changes: 80 additions & 6 deletions ast/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3970,6 +3970,7 @@ func TestCompilerRewriteWithValue(t *testing.T) {
tests := []struct {
note string
input string
opts func(*Compiler) *Compiler
expected string
expectedRule *Rule
wantErr error
Expand Down Expand Up @@ -4075,6 +4076,26 @@ func TestCompilerRewriteWithValue(t *testing.T) {
return r
}(),
},
{
note: "built-in function: replaced by another built-in that's marked unsafe",
input: `
q := is_object({"url": "https://httpbin.org", "method": "GET"})
p { q with is_object as http.send }
`,
opts: func(c *Compiler) *Compiler { return c.WithUnsafeBuiltins(map[string]struct{}{"http.send": {}}) },
wantErr: fmt.Errorf("rego_compile_error: with keyword replacing built-in function: target must not be unsafe: \"http.send\""),
},
{
note: "non-built-in function: replaced by another built-in that's marked unsafe",
input: `
r(_) = {}
q := r({"url": "https://httpbin.org", "method": "GET"})
p {
q with r as http.send
}`,
opts: func(c *Compiler) *Compiler { return c.WithUnsafeBuiltins(map[string]struct{}{"http.send": {}}) },
wantErr: fmt.Errorf("rego_compile_error: with keyword replacing built-in function: target must not be unsafe: \"http.send\""),
},
{
note: "built-in function: valid, arity 1, non-compound name",
input: `
Expand All @@ -4092,6 +4113,9 @@ func TestCompilerRewriteWithValue(t *testing.T) {
for _, tc := range tests {
t.Run(tc.note, func(t *testing.T) {
c := NewCompiler()
if tc.opts != nil {
c = tc.opts(c)
}
module := fixture + tc.input
c.Modules["test"] = MustParseModule(module)
compileStages(c, c.rewriteWithModifiers)
Expand Down Expand Up @@ -6597,13 +6621,63 @@ func TestQueryCompilerWithStageAfterWithMetrics(t *testing.T) {
}

func TestQueryCompilerWithUnsafeBuiltins(t *testing.T) {
c := NewCompiler().WithUnsafeBuiltins(map[string]struct{}{
"count": {},
})
tests := []struct {
note string
query string
compiler *Compiler
opts func(QueryCompiler) QueryCompiler
err string
}{
{
note: "builtin unsafe via compiler",
query: "count([])",
compiler: NewCompiler().WithUnsafeBuiltins(map[string]struct{}{"count": {}}),
err: "unsafe built-in function calls in expression: count",
},
{
note: "builtin unsafe via query compiler",
query: "count([])",
compiler: NewCompiler(),
opts: func(qc QueryCompiler) QueryCompiler {
return qc.WithUnsafeBuiltins(map[string]struct{}{"count": {}})
},
err: "unsafe built-in function calls in expression: count",
},
{
note: "builtin unsafe via compiler, 'with' mocking",
query: "is_array([]) with is_array as count",
compiler: NewCompiler().WithUnsafeBuiltins(map[string]struct{}{"count": {}}),
err: `with keyword replacing built-in function: target must not be unsafe: "count"`,
},
{
note: "builtin unsafe via query compiler, 'with' mocking",
query: "is_array([]) with is_array as count",
compiler: NewCompiler(),
opts: func(qc QueryCompiler) QueryCompiler {
return qc.WithUnsafeBuiltins(map[string]struct{}{"count": {}})
},
err: `with keyword replacing built-in function: target must not be unsafe: "count"`,
},
}

_, err := c.QueryCompiler().WithUnsafeBuiltins(map[string]struct{}{}).Compile(MustParseBody("count([])"))
if err != nil {
t.Fatal(err)
for _, tc := range tests {
t.Run(tc.note, func(t *testing.T) {
qc := tc.compiler.QueryCompiler()
if tc.opts != nil {
qc = tc.opts(qc)
}
_, err := qc.Compile(MustParseBody(tc.query))
var errs Errors
if !errors.As(err, &errs) {
t.Fatalf("expected error type %T, got %v %[2]T", errs, err)
}
if exp, act := 1, len(errs); exp != act {
t.Fatalf("expected %d error(s), got %d", exp, act)
}
if exp, act := tc.err, errs[0].Message; exp != act {
t.Errorf("expected message %q, got %q", exp, act)
}
})
}
}

Expand Down
51 changes: 51 additions & 0 deletions rego/rego_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,7 @@ func TestUnsafeBuiltins(t *testing.T) {
ctx := context.Background()

unsafeCountExpr := "unsafe built-in function calls in expression: count"
unsafeCountExprWith := `with keyword replacing built-in function: target must not be unsafe: "count"`

t.Run("unsafe query", func(t *testing.T) {
r := New(
Expand All @@ -1447,6 +1448,16 @@ func TestUnsafeBuiltins(t *testing.T) {
}
})

t.Run("unsafe query, 'with' replacement", func(t *testing.T) {
r := New(
Query(`is_array([1, 2, 3]) with is_array as count`),
UnsafeBuiltins(map[string]struct{}{"count": {}}),
)
if _, err := r.Eval(ctx); err == nil || !strings.Contains(err.Error(), unsafeCountExprWith) {
t.Fatalf("Expected unsafe built-in error but got %v", err)
}
})

t.Run("unsafe module", func(t *testing.T) {
r := New(
Query(`data.pkg.deny`),
Expand All @@ -1462,6 +1473,36 @@ func TestUnsafeBuiltins(t *testing.T) {
}
})

t.Run("unsafe module, 'with' replacement in query", func(t *testing.T) {
r := New(
Query(`data.pkg.deny with is_array as count`),
Module("pkg.rego", `package pkg
deny {
is_array(input.requests) > 10
}
`),
UnsafeBuiltins(map[string]struct{}{"count": {}}),
)
if _, err := r.Eval(ctx); err == nil || !strings.Contains(err.Error(), unsafeCountExprWith) {
t.Fatalf("Expected unsafe built-in error but got %v", err)
}
})

t.Run("unsafe module, 'with' replacement in module", func(t *testing.T) {
r := New(
Query(`data.pkg.deny`),
Module("pkg.rego", `package pkg
deny {
is_array(input.requests) > 10 with is_array as count
}
`),
UnsafeBuiltins(map[string]struct{}{"count": {}}),
)
if _, err := r.Eval(ctx); err == nil || !strings.Contains(err.Error(), unsafeCountExprWith) {
t.Fatalf("Expected unsafe built-in error but got %v", err)
}
})

t.Run("inherit in query", func(t *testing.T) {
r := New(
Compiler(ast.NewCompiler().WithUnsafeBuiltins(map[string]struct{}{"count": {}})),
Expand All @@ -1472,6 +1513,16 @@ func TestUnsafeBuiltins(t *testing.T) {
}
})

t.Run("inherit in query, 'with' replacement", func(t *testing.T) {
r := New(
Compiler(ast.NewCompiler().WithUnsafeBuiltins(map[string]struct{}{"count": {}})),
Query("is_array([]) with is_array as count"),
)
if _, err := r.Eval(ctx); err == nil || !strings.Contains(err.Error(), unsafeCountExprWith) {
t.Fatalf("Expected unsafe built-in error but got %v", err)
}
})

t.Run("override/disable in query", func(t *testing.T) {
r := New(
Compiler(ast.NewCompiler().WithUnsafeBuiltins(map[string]struct{}{"count": {}})),
Expand Down

0 comments on commit 3e8c754

Please sign in to comment.