diff --git a/pkg/tools/opa/policy.go b/pkg/tools/opa/policy.go index acfd0d906..9ffb64e4d 100644 --- a/pkg/tools/opa/policy.go +++ b/pkg/tools/opa/policy.go @@ -1,6 +1,6 @@ // Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // -// Copyright (c) 2023-2024 Cisco and/or its affiliates. +// Copyright (c) 2023 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -20,10 +20,10 @@ package opa import ( "context" - "fmt" "os" "path/filepath" "strings" + "sync" "github.com/open-policy-agent/opa/rego" "github.com/pkg/errors" @@ -77,12 +77,15 @@ func WithNamedPolicyFromSource(name, source, query string, checkQuery CheckQuery // AuthorizationPolicy checks that passed tokens are valid type AuthorizationPolicy struct { + initErr error name string policyFilePath string policySource string pkg string query string + evalQuery *rego.PreparedEvalQuery checker CheckAccessFunc + once sync.Once } // Name returns AuthorizationPolicy name @@ -96,11 +99,10 @@ func (d *AuthorizationPolicy) Check(ctx context.Context, model interface{}) erro if err != nil { return err } - evalQuery, intErr := d.init() - if intErr != nil { + if intErr := d.init(); intErr != nil { return intErr } - rs, err := evalQuery.Eval(ctx, rego.EvalInput(input)) + rs, err := d.evalQuery.Eval(ctx, rego.EvalInput(input)) if err != nil { return status.Error(codes.Internal, err.Error()) } @@ -114,23 +116,33 @@ func (d *AuthorizationPolicy) Check(ctx context.Context, model interface{}) erro return nil } -func (d *AuthorizationPolicy) init() (*rego.PreparedEvalQuery, error) { - if d.query == "" { - d.query = strings.TrimSuffix(filepath.Base(d.policyFilePath), filepath.Ext(d.policyFilePath)) - } - if err := d.loadSource(); err != nil { - return nil, err - } - if err := d.checkModule(); err != nil { - return nil, err +func (d *AuthorizationPolicy) init() error { + d.once.Do(func() { + if d.query == "" { + d.query = strings.TrimSuffix(filepath.Base(d.policyFilePath), filepath.Ext(d.policyFilePath)) + } + if d.initErr = d.loadSource(); d.initErr != nil { + return + } + if d.initErr = d.checkModule(); d.initErr != nil { + return + } + var r rego.PreparedEvalQuery + r, d.initErr = rego.New( + rego.Query(strings.Join([]string{"data", d.pkg, d.query}, ".")), + rego.Module(d.pkg, d.policySource)).PrepareForEval(context.Background()) + if d.initErr != nil { + return + } + d.evalQuery = &r + }) + if d.initErr != nil { + return d.initErr } - r, err := rego.New( - rego.Query(strings.Join([]string{"data", d.pkg, d.query}, ".")), - rego.Module(d.pkg, d.policySource)).PrepareForEval(context.Background()) - if err != nil { - return nil, errors.Wrap(err, fmt.Sprintf("policy %v is not compiled", d.policySource)) + if d.evalQuery == nil { + return errors.Errorf("policy %v is not compiled", d.policySource) } - return &r, nil + return nil } func (d *AuthorizationPolicy) loadSource() error {