diff --git a/internal/dispatch/graph/computecheck_test.go b/internal/dispatch/graph/computecheck_test.go index 3276fa9d1e..9fbc95c861 100644 --- a/internal/dispatch/graph/computecheck_test.go +++ b/internal/dispatch/graph/computecheck_test.go @@ -469,7 +469,7 @@ func TestComputeCheckWithCaveats(t *testing.T) { }`, map[string]caveatDefinition{ "attributes_match": { - "expected.all(x, expected[x] == provided[x])", + "expected.isSubtreeOf(provided)", map[string]types.VariableType{ "expected": types.MapType(types.AnyType), "provided": types.MapType(types.AnyType), @@ -525,8 +525,9 @@ func TestComputeCheckWithCaveats(t *testing.T) { "provided": map[string]any{ "type": "backend", "region": "us", "team": "shop", "additional_attrs": map[string]any{ - "tag1": 100, + "tag1": 100.0, "tag2": false, + "tag3": "hi", }, }, }, diff --git a/pkg/caveats/env.go b/pkg/caveats/env.go index 19014c929a..7d0cfdf938 100644 --- a/pkg/caveats/env.go +++ b/pkg/caveats/env.go @@ -66,8 +66,6 @@ func (e *Environment) asCelEnvironment() (*cel.Env, error) { opts = append(opts, customTypeOpts...) } - opts = append(opts, types.TypeMethods...) - // Set options. // DefaultUTCTimeZone: ensure all timestamps are evaluated at UTC opts = append(opts, cel.DefaultUTCTimeZone(true)) diff --git a/pkg/caveats/types/map.go b/pkg/caveats/types/map.go new file mode 100644 index 0000000000..46905a1117 --- /dev/null +++ b/pkg/caveats/types/map.go @@ -0,0 +1,46 @@ +package types + +import ( + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" +) + +func init() { + subtreeFunction := cel.Function("isSubtreeOf", + cel.MemberOverload("subtree_maps", + []*cel.Type{cel.MapType(cel.StringType, cel.AnyType), cel.MapType(cel.StringType, cel.AnyType)}, + cel.BoolType, + cel.FunctionBinding(func(arg ...ref.Val) ref.Val { + map0 := arg[0].Value().(map[string]any) + map1 := arg[1].Value().(map[string]any) + return types.Bool(subtree(map0, map1)) + }))) + + CustomTypes["__map_subtree"] = []cel.EnvOption{subtreeFunction} +} + +func subtree(map0 map[string]any, map1 map[string]any) bool { + for k, v := range map0 { + val, ok := map1[k] + if !ok { + return false + } + nestedMap0, ok := v.(map[string]any) + if ok { + nestedMap1, ok := val.(map[string]any) + if !ok { + return false + } + nestedResult := subtree(nestedMap0, nestedMap1) + if !nestedResult { + return false + } + } else { + if v != val { + return false + } + } + } + return true +} diff --git a/pkg/caveats/types/map_test.go b/pkg/caveats/types/map_test.go new file mode 100644 index 0000000000..c409508f53 --- /dev/null +++ b/pkg/caveats/types/map_test.go @@ -0,0 +1,61 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMapSubtree(t *testing.T) { + tcs := []struct { + map1 map[string]any + map2 map[string]any + subtree bool + }{ + { + map[string]any{"a": 1}, + map[string]any{"a": 1}, + true, + }, + { + map[string]any{"a": 1, "b": 2}, + map[string]any{"a": 1}, + false, + }, + { + map[string]any{"a": 1}, + map[string]any{"a": 1, "b": 1}, + true, + }, + { + map[string]any{"a": 1, "b": map[string]any{"a": 1}}, + map[string]any{"a": 1}, + false, + }, + { + map[string]any{"a": 1, "b": map[string]any{"a": 1}}, + map[string]any{"a": 1, "b": map[string]any{"a": 1}}, + true, + }, + { + map[string]any{"a": 1, "b": map[string]any{"a": 1}}, + map[string]any{"a": 1, "b": map[string]any{"a": 1, "b": 1}}, + true, + }, + { + map[string]any{"a": 1, "b": map[string]any{"a": 1}}, + map[string]any{"a": 1, "b": map[string]any{"a": "1", "b": 1}}, + false, + }, + { + map[string]any{"a": 1, "b": map[string]any{"a": 1}}, + map[string]any{"a": 1, "b": map[string]any{"a": 1, "b": map[string]any{}}}, + true, + }, + } + for _, tt := range tcs { + t.Run("", func(t *testing.T) { + require.Equal(t, tt.subtree, subtree(tt.map1, tt.map2)) + }) + } +}