diff --git a/pkg/tools/opa/policy.go b/pkg/tools/opa/policy.go index 9ffb64e4d..acfd0d906 100644 --- a/pkg/tools/opa/policy.go +++ b/pkg/tools/opa/policy.go @@ -1,6 +1,6 @@ // Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // -// Copyright (c) 2023 Cisco and/or its affiliates. +// Copyright (c) 2023-2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -20,10 +20,10 @@ package opa import ( "context" + "fmt" "os" "path/filepath" "strings" - "sync" "github.com/open-policy-agent/opa/rego" "github.com/pkg/errors" @@ -77,15 +77,12 @@ func WithNamedPolicyFromSource(name, source, query string, checkQuery CheckQuery // AuthorizationPolicy checks that passed tokens are valid type AuthorizationPolicy struct { - initErr error name string policyFilePath string policySource string pkg string query string - evalQuery *rego.PreparedEvalQuery checker CheckAccessFunc - once sync.Once } // Name returns AuthorizationPolicy name @@ -99,10 +96,11 @@ func (d *AuthorizationPolicy) Check(ctx context.Context, model interface{}) erro if err != nil { return err } - if intErr := d.init(); intErr != nil { + evalQuery, intErr := d.init() + if intErr != nil { return intErr } - rs, err := d.evalQuery.Eval(ctx, rego.EvalInput(input)) + rs, err := evalQuery.Eval(ctx, rego.EvalInput(input)) if err != nil { return status.Error(codes.Internal, err.Error()) } @@ -116,33 +114,23 @@ func (d *AuthorizationPolicy) Check(ctx context.Context, model interface{}) erro return nil } -func (d *AuthorizationPolicy) init() error { - d.once.Do(func() { - if d.query == "" { - d.query = strings.TrimSuffix(filepath.Base(d.policyFilePath), filepath.Ext(d.policyFilePath)) - } - if d.initErr = d.loadSource(); d.initErr != nil { - return - } - if d.initErr = d.checkModule(); d.initErr != nil { - return - } - var r rego.PreparedEvalQuery - r, d.initErr = rego.New( - rego.Query(strings.Join([]string{"data", d.pkg, d.query}, ".")), - rego.Module(d.pkg, d.policySource)).PrepareForEval(context.Background()) - if d.initErr != nil { - return - } - d.evalQuery = &r - }) - if d.initErr != nil { - return d.initErr +func (d *AuthorizationPolicy) init() (*rego.PreparedEvalQuery, error) { + if d.query == "" { + d.query = strings.TrimSuffix(filepath.Base(d.policyFilePath), filepath.Ext(d.policyFilePath)) } - if d.evalQuery == nil { - return errors.Errorf("policy %v is not compiled", d.policySource) + if err := d.loadSource(); err != nil { + return nil, err } - return nil + if err := d.checkModule(); err != nil { + return nil, err + } + r, err := rego.New( + rego.Query(strings.Join([]string{"data", d.pkg, d.query}, ".")), + rego.Module(d.pkg, d.policySource)).PrepareForEval(context.Background()) + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("policy %v is not compiled", d.policySource)) + } + return &r, nil } func (d *AuthorizationPolicy) loadSource() error {