Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(misconf): parse rego input once #6615

Merged
merged 4 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions pkg/iac/rego/exceptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package rego
import (
"context"
"fmt"

"github.com/open-policy-agent/opa/ast"
)

func (s *Scanner) isIgnored(ctx context.Context, namespace, ruleName string, input interface{}) (bool, error) {
func (s *Scanner) isIgnored(ctx context.Context, namespace, ruleName string, input ast.Value) (bool, error) {
if ignored, err := s.isNamespaceIgnored(ctx, namespace, input); err != nil {
return false, err
} else if ignored {
Expand All @@ -14,7 +16,7 @@ func (s *Scanner) isIgnored(ctx context.Context, namespace, ruleName string, inp
return s.isRuleIgnored(ctx, namespace, ruleName, input)
}

func (s *Scanner) isNamespaceIgnored(ctx context.Context, namespace string, input interface{}) (bool, error) {
func (s *Scanner) isNamespaceIgnored(ctx context.Context, namespace string, input ast.Value) (bool, error) {
exceptionQuery := fmt.Sprintf("data.namespace.exceptions.exception[_] == %q", namespace)
result, _, err := s.runQuery(ctx, exceptionQuery, input, true)
if err != nil {
Expand All @@ -23,7 +25,7 @@ func (s *Scanner) isNamespaceIgnored(ctx context.Context, namespace string, inpu
return result.Allowed(), nil
}

func (s *Scanner) isRuleIgnored(ctx context.Context, namespace, ruleName string, input interface{}) (bool, error) {
func (s *Scanner) isRuleIgnored(ctx context.Context, namespace, ruleName string, input ast.Value) (bool, error) {
exceptionQuery := fmt.Sprintf("endswith(%q, data.%s.exception[_][_])", ruleName, namespace)
result, _, err := s.runQuery(ctx, exceptionQuery, input, true)
if err != nil {
Expand Down
35 changes: 28 additions & 7 deletions pkg/iac/rego/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
"github.com/open-policy-agent/opa/storage"
"github.com/open-policy-agent/opa/util"

"github.com/aquasecurity/trivy/pkg/iac/debug"
"github.com/aquasecurity/trivy/pkg/iac/framework"
Expand Down Expand Up @@ -161,7 +162,7 @@ func (s *Scanner) SetParentDebugLogger(l debug.Logger) {
s.debug = l.Extend("rego")
}

func (s *Scanner) runQuery(ctx context.Context, query string, input interface{}, disableTracing bool) (rego.ResultSet, []string, error) {
func (s *Scanner) runQuery(ctx context.Context, query string, input ast.Value, disableTracing bool) (rego.ResultSet, []string, error) {

trace := (s.traceWriter != nil || s.tracePerResult) && !disableTracing

Expand All @@ -180,7 +181,7 @@ func (s *Scanner) runQuery(ctx context.Context, query string, input interface{},
}

if input != nil {
regoOptions = append(regoOptions, rego.Input(input))
regoOptions = append(regoOptions, rego.ParsedInput(input))
}

instance := rego.New(regoOptions...)
Expand Down Expand Up @@ -342,6 +343,14 @@ func isPolicyApplicable(staticMetadata *StaticMetadata, inputs ...Input) bool {
return false
}

// parseRawInput converts an arbitrary Go value into an OPA ast.Value.
//
// The value is first passed through util.RoundTrip and the result is then
// handed to ast.InterfaceToValue. Callers parse an input once and reuse the
// resulting ast.Value for every query evaluated against that input.
//
// It returns the error from whichever of the two conversion steps fails,
// unmodified.
func parseRawInput(input any) (ast.Value, error) {
	normalized := input

	// RoundTrip rewrites the value in place through the pointer.
	err := util.RoundTrip(&normalized)
	if err != nil {
		return nil, err
	}

	return ast.InterfaceToValue(normalized)
}
Comment on lines +346 to +352
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious why this is better than simply passing the input as-is to the OPA engine? Wouldn't the engine call the roundtripper on its own?

I'm also a little wary of calling this on every single input. It could get very expensive. I almost wonder if we should write some benchmark tests of our own to evaluate this rather than just using the Minikube repo as an input.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rego is applied to the input data in three places, and the input data is parsed each time. So it makes sense to pass already-parsed data. Why would it be expensive? Rego does the same thing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. After reading the code again, your logic makes sense to me.

We've been bitten by OPA/Rego performance issues in the past, so I'm always a little more skeptical changing things around on that end :)


func (s *Scanner) applyRule(ctx context.Context, namespace, rule string, inputs []Input, combined bool) (scan.Results, error) {

// handle combined evaluations if possible
Expand All @@ -354,7 +363,12 @@ func (s *Scanner) applyRule(ctx context.Context, namespace, rule string, inputs
qualified := fmt.Sprintf("data.%s.%s", namespace, rule)
for _, input := range inputs {
s.trace("INPUT", input)
if ignored, err := s.isIgnored(ctx, namespace, rule, input.Contents); err != nil {
parsedInput, err := parseRawInput(input.Contents)
if err != nil {
s.debug.Log("Error occurred while parsing input: %s", err)
continue
}
if ignored, err := s.isIgnored(ctx, namespace, rule, parsedInput); err != nil {
return nil, err
} else if ignored {
var result regoResult
Expand All @@ -364,7 +378,7 @@ func (s *Scanner) applyRule(ctx context.Context, namespace, rule string, inputs
results.AddIgnored(result)
continue
}
set, traces, err := s.runQuery(ctx, qualified, input.Contents, false)
set, traces, err := s.runQuery(ctx, qualified, parsedInput, false)
if err != nil {
return nil, err
}
Expand All @@ -388,9 +402,15 @@ func (s *Scanner) applyRuleCombined(ctx context.Context, namespace, rule string,
if len(inputs) == 0 {
return nil, nil
}

parsed, err := parseRawInput(inputs)
if err != nil {
return nil, fmt.Errorf("failed to parse input: %w", err)
}

var results scan.Results
qualified := fmt.Sprintf("data.%s.%s", namespace, rule)
if ignored, err := s.isIgnored(ctx, namespace, rule, inputs); err != nil {

if ignored, err := s.isIgnored(ctx, namespace, rule, parsed); err != nil {
return nil, err
} else if ignored {
for _, input := range inputs {
Expand All @@ -402,7 +422,8 @@ func (s *Scanner) applyRuleCombined(ctx context.Context, namespace, rule string,
}
return results, nil
}
set, traces, err := s.runQuery(ctx, qualified, inputs, false)
qualified := fmt.Sprintf("data.%s.%s", namespace, rule)
set, traces, err := s.runQuery(ctx, qualified, parsed, false)
if err != nil {
return nil, err
}
Expand Down