Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for context proto declarations #1006

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions policy/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ go_library(
"//common/types/ref:go_default_library",
"//ext:go_default_library",
"@in_gopkg_yaml_v3//:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
"@org_golang_google_protobuf//reflect/protoregistry:go_default_library",
],
)

Expand All @@ -58,6 +61,7 @@ go_test(
deps = [
"//cel:go_default_library",
"//common/types:go_default_library",
"//interpreter:go_default_library",
"//common/types/ref:go_default_library",
"//test/proto3pb:go_default_library",
"@in_gopkg_yaml_v3//:go_default_library",
Expand Down
54 changes: 46 additions & 8 deletions policy/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@ package policy

import (
"fmt"
"reflect"
"strings"
"testing"

"google.golang.org/protobuf/proto"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
)

func TestCompile(t *testing.T) {
Expand Down Expand Up @@ -205,14 +209,31 @@ func (r *runner) run(t *testing.T) {
tc := tst
t.Run(fmt.Sprintf("%s/%s/%s", r.name, section, tc.Name), func(t *testing.T) {
input := map[string]any{}
var err error
var activation interpreter.Activation
for k, v := range tc.Input {
if v.Expr == "" {
input[k] = v.Value
if v.Expr != "" {
input[k] = r.eval(t, v.Expr)
continue
}
input[k] = r.eval(t, v.Expr)
if v.ContextExpr != "" {
ctx, err := r.eval(t, v.ContextExpr).ConvertToNative(
reflect.TypeOf(((*proto.Message)(nil))).Elem())
if err != nil {
t.Fatalf("context variable is not a valid proto: %v", err)
}
activation, err = cel.ContextProtoVars(ctx.(proto.Message))
break
}
input[k] = v.Value
}
if activation == nil {
activation, err = interpreter.NewActivation(input)
if err != nil {
t.Fatalf("interpreter.NewActivation(input) failed: %v", err)
}
}
out, _, err := r.prg.Eval(input)
out, _, err := r.prg.Eval(activation)
if err != nil {
t.Fatalf("prg.Eval(input) failed: %v", err)
}
Expand Down Expand Up @@ -241,15 +262,32 @@ func (r *runner) bench(b *testing.B) {
tc := tst
b.Run(fmt.Sprintf("%s/%s/%s", r.name, section, tc.Name), func(b *testing.B) {
input := map[string]any{}
var err error
var activation interpreter.Activation
for k, v := range tc.Input {
if v.Expr == "" {
input[k] = v.Value
if v.Expr != "" {
input[k] = r.eval(b, v.Expr)
continue
}
input[k] = r.eval(b, v.Expr)
if v.ContextExpr != "" {
ctx, err := r.eval(b, v.ContextExpr).ConvertToNative(
reflect.TypeOf(((*proto.Message)(nil))).Elem())
if err != nil {
b.Fatalf("context variable is not a valid proto: %v", err)
}
activation, err = cel.ContextProtoVars(ctx.(proto.Message))
break
}
input[k] = v.Value
}
if activation == nil {
activation, err = interpreter.NewActivation(input)
if err != nil {
b.Fatalf("interpreter.NewActivation(input) failed: %v", err)
}
}
for i := 0; i < b.N; i++ {
_, _, err := r.prg.Eval(input)
_, _, err := r.prg.Eval(activation)
if err != nil {
b.Fatalf("policy eval failed: %v", err)
}
Expand Down
30 changes: 24 additions & 6 deletions policy/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
package policy

import (
"errors"
"fmt"
"math"
"strconv"

"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/ext"
)
Expand Down Expand Up @@ -105,20 +109,34 @@ func (ec *ExtensionConfig) AsEnvOption(baseEnv *cel.Env) (cel.EnvOption, error)

// VariableDecl represents a YAML serializable CEL variable declaration.
type VariableDecl struct {
Name string `yaml:"name"`
Type *TypeDecl `yaml:"type"`
Name string `yaml:"name"`
Type *TypeDecl `yaml:"type"`
ContextProto string `yaml:"context_proto"`
}

// AsEnvOption converts a VariableDecl type to a CEL environment option.
//
// Note, variable definitions with differing type definitions will result in an error during
// the compile step.
func (vd *VariableDecl) AsEnvOption(baseEnv *cel.Env) (cel.EnvOption, error) {
t, err := vd.Type.AsCELType(baseEnv)
if err != nil {
return nil, err
if vd.Name != "" {
t, err := vd.Type.AsCELType(baseEnv)
if err != nil {
return nil, fmt.Errorf("invalid variable type for '%s': %w", vd.Name, err)
}
return cel.Variable(vd.Name, t), nil
}
if vd.ContextProto != "" {
if _, found := baseEnv.CELTypeProvider().FindStructType(vd.ContextProto); !found {
return nil, fmt.Errorf("could not find context proto type name: %s", vd.ContextProto)
}
messageType, err := protoregistry.GlobalTypes.FindMessageByName(protoreflect.FullName(vd.ContextProto))
if err == protoregistry.NotFound {
return nil, fmt.Errorf("could not find context proto type name: %s", vd.ContextProto)
}
return cel.DeclareContextProto(messageType.Descriptor()), nil
}
return cel.Variable(vd.Name, t), nil
return nil, errors.New("invalid variable, must set 'name' or 'context_proto' field")
}

// TypeDecl represents a YAML serializable CEL type reference.
Expand Down
25 changes: 20 additions & 5 deletions policy/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,15 @@ variables:
- name: "bad_type"
type:
type_name: "strings"`,
err: "undefined type name: strings",
err: "invalid variable type for 'bad_type': undefined type name: strings",
},
{
config: `
variables:
- name: "bad_list"
type:
type_name: "list"`,
err: "list type has unexpected param count: 0",
err: "invalid variable type for 'bad_list': list type has unexpected param count: 0",
},
{
config: `
Expand All @@ -146,7 +146,7 @@ variables:
type_name: "map"
params:
- type_name: "string"`,
err: "map type has unexpected param count: 1",
err: "invalid variable type for 'bad_map': map type has unexpected param count: 1",
},
{
config: `
Expand All @@ -156,7 +156,7 @@ variables:
type_name: "list"
params:
- type_name: "number"`,
err: "undefined type name: number",
err: "invalid variable type for 'bad_list_type_param': undefined type name: number",
},
{
config: `
Expand All @@ -167,8 +167,23 @@ variables:
params:
- type_name: "string"
- type_name: "optional"`,
err: "undefined type name: optional",
err: "invalid variable type for 'bad_map_type_param': undefined type name: optional",
},
{
config: `
variables:
- context_proto: "bad.proto.MessageType"
`,
err: "could not find context proto type name: bad.proto.MessageType",
},
{
config: `
variables:
- type:
type_name: "no variable name"`,
err: "invalid variable, must set 'name' or 'context_proto' field",
},

{
config: `
functions:
Expand Down
10 changes: 8 additions & 2 deletions policy/conformance.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ type TestCase struct {

// TestInput represents an input literal value or expression.
type TestInput struct {
Value any `yaml:"value"`
Expr string `yaml:"expr"`
// Value is a simple literal value.
Value any `yaml:"value"`

// Expr is a CEL expression based input.
Expr string `yaml:"expr"`

// ContextExpr is a CEL expression which is used as cel.ContextProtoVars
ContextExpr string `yaml:"context_expr"`
}
13 changes: 13 additions & 0 deletions policy/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,19 @@ var (
: (!(resource.origin in variables.permitted_regions)
? optional.of({"banned": "unconfigured_region"}) : optional.none()))`,
},
{
name: "context_pb",
expr: `
(single_int32 > google.expr.proto3.test.TestAllTypes{single_int64: 10}.single_int64)
? optional.of("invalid spec, got single_int32=%d, wanted <= 10".format([single_int32]))
: ((standalone_enum == google.expr.proto3.test.TestAllTypes.NestedEnum.BAR ||
google.expr.proto3.test.ImportedGlobalEnum.IMPORT_BAR in imported_enums)
? optional.of("invalid spec, neither nested nor imported enums may refer to BAR or IMPORT_BAR")
: optional.none())`,
envOpts: []cel.EnvOption{
cel.Types(&proto3pb.TestAllTypes{}),
},
},
{
name: "pb",
expr: `
Expand Down
21 changes: 21 additions & 0 deletions policy/testdata/context_pb/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: "context_pb"
container: "google.expr.proto3"
extensions:
- name: "strings"
version: 2
variables:
- context_proto: "google.expr.proto3.test.TestAllTypes"
33 changes: 33 additions & 0 deletions policy/testdata/context_pb/policy.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

name: "context_pb"

imports:
- name: google.expr.proto3.test.TestAllTypes
- name: google.expr.proto3.test.TestAllTypes.NestedEnum
- name: |
google.expr.proto3.test.ImportedGlobalEnum

rule:
match:
- condition: >
single_int32 > TestAllTypes{single_int64: 10}.single_int64
output: |
"invalid spec, got single_int32=%d, wanted <= 10".format([single_int32])
- condition: >
standalone_enum == NestedEnum.BAR ||
ImportedGlobalEnum.IMPORT_BAR in imported_enums
output: |
"invalid spec, neither nested nor imported enums may refer to BAR or IMPORT_BAR"
33 changes: 33 additions & 0 deletions policy/testdata/context_pb/tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

description: "Protobuf input tests"
section:
- name: "valid"
tests:
- name: "good spec"
input:
spec:
context_expr: >
test.TestAllTypes{single_int32: 10}
output: "optional.none()"
- name: "invalid"
tests:
- name: "bad spec"
input:
spec:
context_expr: >
test.TestAllTypes{single_int32: 11}
output: >
"invalid spec, got single_int32=11, wanted <= 10"