From 3e8c754ed007b22393cf65e48751ad9f6457fee8 Mon Sep 17 00:00:00 2001 From: Stephan Renatus Date: Mon, 5 Sep 2022 09:54:45 +0200 Subject: [PATCH] ast/compile: respect unsafeBuiltinMap for 'with' replacements The changes are necessary for both the Compiler and the QueryCompiler. Tests have been added to ensure that the code path through the rego package has also been fixed. Fixes CVE-2022-36085. Signed-off-by: Stephan Renatus --- ast/compile.go | 58 +++++++++++++++++------------- ast/compile_test.go | 86 +++++++++++++++++++++++++++++++++++++++++---- rego/rego_test.go | 51 +++++++++++++++++++++++++++ 3 files changed, 164 insertions(+), 31 deletions(-) diff --git a/ast/compile.go b/ast/compile.go index 3697229dda..d2841143c4 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -2196,7 +2196,7 @@ func (c *Compiler) rewriteWithModifiers() { if !ok { return x, nil } - body, err := rewriteWithModifiersInBody(c, f, body) + body, err := rewriteWithModifiersInBody(c, c.unsafeBuiltinsMap, f, body) if err != nil { c.err(err) } @@ -2475,19 +2475,20 @@ func (qc *queryCompiler) checkTypes(_ *QueryContext, body Body) (Body, error) { } func (qc *queryCompiler) checkUnsafeBuiltins(_ *QueryContext, body Body) (Body, error) { - var unsafe map[string]struct{} - if qc.unsafeBuiltins != nil { - unsafe = qc.unsafeBuiltins - } else { - unsafe = qc.compiler.unsafeBuiltinsMap - } - errs := checkUnsafeBuiltins(unsafe, body) + errs := checkUnsafeBuiltins(qc.unsafeBuiltinsMap(), body) if len(errs) > 0 { return nil, errs } return body, nil } +func (qc *queryCompiler) unsafeBuiltinsMap() map[string]struct{} { + if qc.unsafeBuiltins != nil { + return qc.unsafeBuiltins + } + return qc.compiler.unsafeBuiltinsMap +} + func (qc *queryCompiler) checkDeprecatedBuiltins(_ *QueryContext, body Body) (Body, error) { errs := checkDeprecatedBuiltins(qc.compiler.deprecatedBuiltinsMap, body, qc.compiler.strict) if len(errs) > 0 { @@ -2498,7 +2499,7 @@ func (qc *queryCompiler) checkDeprecatedBuiltins(_ *QueryContext, body Body) (Bo func (qc *queryCompiler) rewriteWithModifiers(_ *QueryContext, body Body) (Body, error) { f := newEqualityFactory(newLocalVarGenerator("q", body)) - body, err := rewriteWithModifiersInBody(qc.compiler, f, body) + body, err := rewriteWithModifiersInBody(qc.compiler, qc.unsafeBuiltinsMap(), f, body) if err != nil { return nil, Errors{err} } @@ -4779,10 +4780,10 @@ func rewriteDeclaredVar(g *localVarGenerator, stack *localDeclaredVars, v Var, o // rewriteWithModifiersInBody will rewrite the body so that with modifiers do // not contain terms that require evaluation as values. If this function // encounters an invalid with modifier target then it will raise an error. -func rewriteWithModifiersInBody(c *Compiler, f *equalityFactory, body Body) (Body, *Error) { +func rewriteWithModifiersInBody(c *Compiler, unsafeBuiltinsMap map[string]struct{}, f *equalityFactory, body Body) (Body, *Error) { var result Body for i := range body { - exprs, err := rewriteWithModifier(c, f, body[i]) + exprs, err := rewriteWithModifier(c, unsafeBuiltinsMap, f, body[i]) if err != nil { return nil, err } @@ -4797,11 +4798,11 @@ func rewriteWithModifiersInBody(c *Compiler, f *equalityFactory, body Body) (Bod return result, nil } -func rewriteWithModifier(c *Compiler, f *equalityFactory, expr *Expr) ([]*Expr, *Error) { +func rewriteWithModifier(c *Compiler, unsafeBuiltinsMap map[string]struct{}, f *equalityFactory, expr *Expr) ([]*Expr, *Error) { var result []*Expr for i := range expr.With { - eval, err := validateWith(c, expr, i) + eval, err := validateWith(c, unsafeBuiltinsMap, expr, i) if err != nil { return nil, err } @@ -4816,7 +4817,7 @@ func rewriteWithModifier(c *Compiler, f *equalityFactory, expr *Expr) ([]*Expr, return append(result, expr), nil } -func validateWith(c *Compiler, expr *Expr, i int) (bool, *Error) { +func validateWith(c *Compiler, unsafeBuiltinsMap map[string]struct{}, expr *Expr, i int) (bool, *Error) { target, value := expr.With[i].Target, expr.With[i].Value // Ensure that values that are built-ins are rewritten to Ref (not Var) @@ -4825,6 +4826,10 @@ func validateWith(c *Compiler, expr *Expr, i int) (bool, *Error) { value.Value = Ref([]*Term{NewTerm(v)}) } } + isBuiltinRefOrVar, err := isBuiltinRefOrVar(c.builtins, unsafeBuiltinsMap, target) + if err != nil { + return false, err + } switch { case isDataRef(target): @@ -4848,15 +4853,15 @@ func validateWith(c *Compiler, expr *Expr, i int) (bool, *Error) { if child := node.Child(ref[len(ref)-1].Value); child != nil { for _, v := range child.Values { if len(v.(*Rule).Head.Args) > 0 { - if validateWithFunctionValue(c.builtins, c.RuleTree, value) { - return false, nil + if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok { + return false, err // may be nil } } } } } case isInputRef(target): // ok, valid - case isBuiltinRefOrVar(c.builtins, target): + case isBuiltinRefOrVar: // NOTE(sr): first we ensure that parsed Var builtins (`count`, `concat`, etc) // are rewritten to their proper Ref convention @@ -4870,8 +4875,8 @@ func validateWith(c *Compiler, expr *Expr, i int) (bool, *Error) { return false, err } - if validateWithFunctionValue(c.builtins, c.RuleTree, value) { - return false, nil + if ok, err := validateWithFunctionValue(c.builtins, unsafeBuiltinsMap, c.RuleTree, value); err != nil || ok { + return false, err // may be nil } default: return false, NewError(TypeErr, target.Location, "with keyword target must reference existing %v, %v, or a function", InputRootDocument, DefaultRootDocument) @@ -4900,13 +4905,13 @@ func validateWithBuiltinTarget(bi *Builtin, target Ref, loc *location.Location) return nil } -func validateWithFunctionValue(bs map[string]*Builtin, ruleTree *TreeNode, value *Term) bool { +func validateWithFunctionValue(bs map[string]*Builtin, unsafeMap map[string]struct{}, ruleTree *TreeNode, value *Term) (bool, *Error) { if v, ok := value.Value.(Ref); ok { if ruleTree.Find(v) != nil { // ref exists in rule tree - return true + return true, nil } } - return isBuiltinRefOrVar(bs, value) + return isBuiltinRefOrVar(bs, unsafeMap, value) } func isInputRef(term *Term) bool { @@ -4927,13 +4932,16 @@ func isDataRef(term *Term) bool { return false } -func isBuiltinRefOrVar(bs map[string]*Builtin, term *Term) bool { +func isBuiltinRefOrVar(bs map[string]*Builtin, unsafeBuiltinsMap map[string]struct{}, term *Term) (bool, *Error) { switch v := term.Value.(type) { case Ref, Var: + if _, ok := unsafeBuiltinsMap[v.String()]; ok { + return false, NewError(CompileErr, term.Location, "with keyword replacing built-in function: target must not be unsafe: %q", v) + } _, ok := bs[v.String()] - return ok + return ok, nil } - return false + return false, nil } func isVirtual(node *TreeNode, ref Ref) bool { diff --git a/ast/compile_test.go b/ast/compile_test.go index 627b148044..be36c79d15 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -3970,6 +3970,7 @@ func TestCompilerRewriteWithValue(t *testing.T) { tests := []struct { note string input string + opts func(*Compiler) *Compiler expected string expectedRule *Rule wantErr error @@ -4075,6 +4076,26 @@ func TestCompilerRewriteWithValue(t *testing.T) { return r }(), }, + { + note: "built-in function: replaced by another built-in that's marked unsafe", + input: ` + q := is_object({"url": "https://httpbin.org", "method": "GET"}) + p { q with is_object as http.send } + `, + opts: func(c *Compiler) *Compiler { return c.WithUnsafeBuiltins(map[string]struct{}{"http.send": {}}) }, + wantErr: fmt.Errorf("rego_compile_error: with keyword replacing built-in function: target must not be unsafe: \"http.send\""), + }, + { + note: "non-built-in function: replaced by another built-in that's marked unsafe", + input: ` + r(_) = {} + q := r({"url": "https://httpbin.org", "method": "GET"}) + p { + q with r as http.send + }`, + opts: func(c *Compiler) *Compiler { return c.WithUnsafeBuiltins(map[string]struct{}{"http.send": {}}) }, + wantErr: fmt.Errorf("rego_compile_error: with keyword replacing built-in function: target must not be unsafe: \"http.send\""), + }, { note: "built-in function: valid, arity 1, non-compound name", input: ` @@ -4092,6 +4113,9 @@ func TestCompilerRewriteWithValue(t *testing.T) { for _, tc := range tests { t.Run(tc.note, func(t *testing.T) { c := NewCompiler() + if tc.opts != nil { + c = tc.opts(c) + } module := fixture + tc.input c.Modules["test"] = MustParseModule(module) compileStages(c, c.rewriteWithModifiers) @@ -6597,13 +6621,63 @@ func TestQueryCompilerWithStageAfterWithMetrics(t *testing.T) { } func TestQueryCompilerWithUnsafeBuiltins(t *testing.T) { - c := NewCompiler().WithUnsafeBuiltins(map[string]struct{}{ - "count": {}, - }) + tests := []struct { + note string + query string + compiler *Compiler + opts func(QueryCompiler) QueryCompiler + err string + }{ + { + note: "builtin unsafe via compiler", + query: "count([])", + compiler: NewCompiler().WithUnsafeBuiltins(map[string]struct{}{"count": {}}), + err: "unsafe built-in function calls in expression: count", + }, + { + note: "builtin unsafe via query compiler", + query: "count([])", + compiler: NewCompiler(), + opts: func(qc QueryCompiler) QueryCompiler { + return qc.WithUnsafeBuiltins(map[string]struct{}{"count": {}}) + }, + err: "unsafe built-in function calls in expression: count", + }, + { + note: "builtin unsafe via compiler, 'with' mocking", + query: "is_array([]) with is_array as count", + compiler: NewCompiler().WithUnsafeBuiltins(map[string]struct{}{"count": {}}), + err: `with keyword replacing built-in function: target must not be unsafe: "count"`, + }, + { + note: "builtin unsafe via query compiler, 'with' mocking", + query: "is_array([]) with is_array as count", + compiler: NewCompiler(), + opts: func(qc QueryCompiler) QueryCompiler { + return qc.WithUnsafeBuiltins(map[string]struct{}{"count": {}}) + }, + err: `with keyword replacing built-in function: target must not be unsafe: "count"`, + }, + } - _, err := c.QueryCompiler().WithUnsafeBuiltins(map[string]struct{}{}).Compile(MustParseBody("count([])")) - if err != nil { - t.Fatal(err) + for _, tc := range tests { + t.Run(tc.note, func(t *testing.T) { + qc := tc.compiler.QueryCompiler() + if tc.opts != nil { + qc = tc.opts(qc) + } + _, err := qc.Compile(MustParseBody(tc.query)) + var errs Errors + if !errors.As(err, &errs) { + t.Fatalf("expected error type %T, got %v %[2]T", errs, err) + } + if exp, act := 1, len(errs); exp != act { + t.Fatalf("expected %d error(s), got %d", exp, act) + } + if exp, act := tc.err, errs[0].Message; exp != act { + t.Errorf("expected message %q, got %q", exp, act) + } + }) } } diff --git a/rego/rego_test.go b/rego/rego_test.go index 46e86dbaee..32b3278c3d 100644 --- a/rego/rego_test.go +++ b/rego/rego_test.go @@ -1436,6 +1436,7 @@ func TestUnsafeBuiltins(t *testing.T) { ctx := context.Background() unsafeCountExpr := "unsafe built-in function calls in expression: count" + unsafeCountExprWith := `with keyword replacing built-in function: target must not be unsafe: "count"` t.Run("unsafe query", func(t *testing.T) { r := New( @@ -1447,6 +1448,16 @@ func TestUnsafeBuiltins(t *testing.T) { } }) + t.Run("unsafe query, 'with' replacement", func(t *testing.T) { + r := New( + Query(`is_array([1, 2, 3]) with is_array as count`), + UnsafeBuiltins(map[string]struct{}{"count": {}}), + ) + if _, err := r.Eval(ctx); err == nil || !strings.Contains(err.Error(), unsafeCountExprWith) { + t.Fatalf("Expected unsafe built-in error but got %v", err) + } + }) + t.Run("unsafe module", func(t *testing.T) { r := New( Query(`data.pkg.deny`), @@ -1462,6 +1473,36 @@ func TestUnsafeBuiltins(t *testing.T) { } }) + t.Run("unsafe module, 'with' replacement in query", func(t *testing.T) { + r := New( + Query(`data.pkg.deny with is_array as count`), + Module("pkg.rego", `package pkg + deny { + is_array(input.requests) > 10 + } + `), + UnsafeBuiltins(map[string]struct{}{"count": {}}), + ) + if _, err := r.Eval(ctx); err == nil || !strings.Contains(err.Error(), unsafeCountExprWith) { + t.Fatalf("Expected unsafe built-in error but got %v", err) + } + }) + + t.Run("unsafe module, 'with' replacement in module", func(t *testing.T) { + r := New( + Query(`data.pkg.deny`), + Module("pkg.rego", `package pkg + deny { + is_array(input.requests) > 10 with is_array as count + } + `), + UnsafeBuiltins(map[string]struct{}{"count": {}}), + ) + if _, err := r.Eval(ctx); err == nil || !strings.Contains(err.Error(), unsafeCountExprWith) { + t.Fatalf("Expected unsafe built-in error but got %v", err) + } + }) + t.Run("inherit in query", func(t *testing.T) { r := New( Compiler(ast.NewCompiler().WithUnsafeBuiltins(map[string]struct{}{"count": {}})), @@ -1472,6 +1513,16 @@ func TestUnsafeBuiltins(t *testing.T) { } }) + t.Run("inherit in query, 'with' replacement", func(t *testing.T) { + r := New( + Compiler(ast.NewCompiler().WithUnsafeBuiltins(map[string]struct{}{"count": {}})), + Query("is_array([]) with is_array as count"), + ) + if _, err := r.Eval(ctx); err == nil || !strings.Contains(err.Error(), unsafeCountExprWith) { + t.Fatalf("Expected unsafe built-in error but got %v", err) + } + }) + t.Run("override/disable in query", func(t *testing.T) { r := New( Compiler(ast.NewCompiler().WithUnsafeBuiltins(map[string]struct{}{"count": {}})),