From bac22aa2a1fbd979b60640759f8f0d66522dc0f3 Mon Sep 17 00:00:00 2001 From: Liam Galvin Date: Thu, 22 Sep 2022 10:57:15 +0100 Subject: [PATCH] fix: Add support for recursive json schema elements to prevent fatal error Resolves #5166 Signed-off-by: Liam Galvin --- ast/compile.go | 121 +++++++++++++++++++++---------------- ast/compile_test.go | 143 ++++++++++++++++++++++++++++++++++++++++---- ast/schema.go | 2 +- ast/schema_test.go | 40 ++++++++++++- 4 files changed, 241 insertions(+), 65 deletions(-) diff --git a/ast/compile.go b/ast/compile.go index 6e863585f7..39fca0ac56 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -475,16 +475,16 @@ func (c *Compiler) GetArity(ref Ref) int { // // E.g., given the following module: // -// package a.b.c +// package a.b.c // -// p[k] = v { ... } # rule1 -// p[k1] = v1 { ... } # rule2 +// p[k] = v { ... } # rule1 +// p[k1] = v1 { ... } # rule2 // // The following calls yield the rules on the right. // -// GetRulesExact("data.a.b.c.p") => [rule1, rule2] -// GetRulesExact("data.a.b.c.p.x") => nil -// GetRulesExact("data.a.b.c") => nil +// GetRulesExact("data.a.b.c.p") => [rule1, rule2] +// GetRulesExact("data.a.b.c.p.x") => nil +// GetRulesExact("data.a.b.c") => nil func (c *Compiler) GetRulesExact(ref Ref) (rules []*Rule) { node := c.RuleTree @@ -502,16 +502,16 @@ func (c *Compiler) GetRulesExact(ref Ref) (rules []*Rule) { // // E.g., given the following module: // -// package a.b.c +// package a.b.c // -// p[k] = v { ... } # rule1 -// p[k1] = v1 { ... } # rule2 +// p[k] = v { ... } # rule1 +// p[k1] = v1 { ... } # rule2 // // The following calls yield the rules on the right. // -// GetRulesForVirtualDocument("data.a.b.c.p") => [rule1, rule2] -// GetRulesForVirtualDocument("data.a.b.c.p.x") => [rule1, rule2] -// GetRulesForVirtualDocument("data.a.b.c") => nil +// GetRulesForVirtualDocument("data.a.b.c.p") => [rule1, rule2] +// GetRulesForVirtualDocument("data.a.b.c.p.x") => [rule1, rule2] +// GetRulesForVirtualDocument("data.a.b.c") => nil func (c *Compiler) GetRulesForVirtualDocument(ref Ref) (rules []*Rule) { node := c.RuleTree @@ -532,17 +532,17 @@ func (c *Compiler) GetRulesForVirtualDocument(ref Ref) (rules []*Rule) { // // E.g., given the following module: // -// package a.b.c +// package a.b.c // -// p[x] = y { ... } # rule1 -// p[k] = v { ... } # rule2 -// q { ... } # rule3 +// p[x] = y { ... } # rule1 +// p[k] = v { ... } # rule2 +// q { ... } # rule3 // // The following calls yield the rules on the right. // -// GetRulesWithPrefix("data.a.b.c.p") => [rule1, rule2] -// GetRulesWithPrefix("data.a.b.c.p.a") => nil -// GetRulesWithPrefix("data.a.b.c") => [rule1, rule2, rule3] +// GetRulesWithPrefix("data.a.b.c.p") => [rule1, rule2] +// GetRulesWithPrefix("data.a.b.c.p.a") => nil +// GetRulesWithPrefix("data.a.b.c") => [rule1, rule2, rule3] func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) { node := c.RuleTree @@ -581,18 +581,18 @@ func extractRules(s []util.T) (rules []*Rule) { // // E.g., given the following module: // -// package a.b.c +// package a.b.c // -// p[x] = y { q[x] = y; ... } # rule1 -// q[x] = y { ... } # rule2 +// p[x] = y { q[x] = y; ... } # rule1 +// q[x] = y { ... } # rule2 // // The following calls yield the rules on the right. // -// GetRules("data.a.b.c.p") => [rule1] -// GetRules("data.a.b.c.p.x") => [rule1] -// GetRules("data.a.b.c.q") => [rule2] -// GetRules("data.a.b.c") => [rule1, rule2] -// GetRules("data.a.b.d") => nil +// GetRules("data.a.b.c.p") => [rule1] +// GetRules("data.a.b.c.p.x") => [rule1] +// GetRules("data.a.b.c.q") => [rule2] +// GetRules("data.a.b.c") => [rule1, rule2] +// GetRules("data.a.b.d") => nil func (c *Compiler) GetRules(ref Ref) (rules []*Rule) { set := map[*Rule]struct{}{} @@ -627,34 +627,34 @@ func (c *Compiler) GetRulesDynamic(ref Ref) []*Rule { // // E.g., given the following modules: // -// package a.b.c +// package a.b.c // -// r1 = 1 # rule1 +// r1 = 1 # rule1 // // and: // -// package a.d.c +// package a.d.c // -// r2 = 2 # rule2 +// r2 = 2 # rule2 // // The following calls yield the rules on the right. // -// GetRulesDynamicWithOpts("data.a[x].c[y]", opts) => [rule1, rule2] -// GetRulesDynamicWithOpts("data.a[x].c.r2", opts) => [rule2] -// GetRulesDynamicWithOpts("data.a.b[x][y]", opts) => [rule1] +// GetRulesDynamicWithOpts("data.a[x].c[y]", opts) => [rule1, rule2] +// GetRulesDynamicWithOpts("data.a[x].c.r2", opts) => [rule2] +// GetRulesDynamicWithOpts("data.a.b[x][y]", opts) => [rule1] // // Using the RulesOptions parameter, the inclusion of hidden modules can be // controlled: // // With // -// package system.main +// package system.main // -// r3 = 3 # rule3 +// r3 = 3 # rule3 // // We'd get this result: // -// GetRulesDynamicWithOpts("data[x]", RulesOptions{IncludeHiddenModules: true}) => [rule1, rule2, rule3] +// GetRulesDynamicWithOpts("data[x]", RulesOptions{IncludeHiddenModules: true}) => [rule1, rule2, rule3] // // Without the options, it would be excluded. func (c *Compiler) GetRulesDynamicWithOpts(ref Ref, opts RulesOptions) []*Rule { @@ -1053,7 +1053,17 @@ func mergeSchemas(schemas ...*gojsonschema.SubSchema) (*gojsonschema.SubSchema, return result, nil } -func parseSchema(schema interface{}) (types.Type, error) { +type schemaParser struct { + definitionCache map[string][]*types.StaticProperty +} + +func newSchemaParser() *schemaParser { + return &schemaParser{ + definitionCache: map[string][]*types.StaticProperty{}, + } +} + +func (parser *schemaParser) parseSchema(schema interface{}) (types.Type, error) { subSchema, ok := schema.(*gojsonschema.SubSchema) if !ok { return nil, fmt.Errorf("unexpected schema type %v", subSchema) @@ -1061,7 +1071,10 @@ func parseSchema(schema interface{}) (types.Type, error) { // Handle referenced schemas, returns directly when a $ref is found if subSchema.RefSchema != nil { - return parseSchema(subSchema.RefSchema) + if existing, ok := parser.definitionCache[subSchema.RefSchema.ID.String()]; ok { + return types.NewObject(existing, nil), nil + } + return parser.parseSchema(subSchema.RefSchema) } // Handle anyOf @@ -1073,7 +1086,7 @@ func parseSchema(schema interface{}) (types.Type, error) { copySchema := *subSchema copySchemaRef := ©Schema copySchemaRef.AnyOf = nil - coreType, err := parseSchema(copySchemaRef) + coreType, err := parser.parseSchema(copySchemaRef) if err != nil { return nil, fmt.Errorf("unexpected schema type %v: %w", subSchema, err) } @@ -1088,7 +1101,7 @@ func parseSchema(schema interface{}) (types.Type, error) { // Iterate through every property of AnyOf and add it to orType for _, pSchema := range subSchema.AnyOf { - newtype, err := parseSchema(pSchema) + newtype, err := parser.parseSchema(pSchema) if err != nil { return nil, fmt.Errorf("unexpected schema type %v: %w", pSchema, err) } @@ -1111,12 +1124,12 @@ func parseSchema(schema interface{}) (types.Type, error) { if err != nil { return nil, err } - return parseSchema(objectOrArrayResult) + return parser.parseSchema(objectOrArrayResult) } else if subSchema.Types.String() != allOfResult.Types.String() { return nil, fmt.Errorf("unable to merge these schemas") } } - return parseSchema(allOfResult) + return parser.parseSchema(allOfResult) } if subSchema.Types.IsTyped() { @@ -1133,11 +1146,17 @@ func parseSchema(schema interface{}) (types.Type, error) { if len(subSchema.PropertiesChildren) > 0 { staticProps := make([]*types.StaticProperty, 0, len(subSchema.PropertiesChildren)) for _, pSchema := range subSchema.PropertiesChildren { - newtype, err := parseSchema(pSchema) + staticProps = append(staticProps, types.NewStaticProperty(pSchema.Property, nil)) + } + if subSchema.Parent != nil { + parser.definitionCache[subSchema.ID.String()] = staticProps + } + for i, pSchema := range subSchema.PropertiesChildren { + newtype, err := parser.parseSchema(pSchema) if err != nil { return nil, fmt.Errorf("unexpected schema type %v: %w", pSchema, err) } - staticProps = append(staticProps, types.NewStaticProperty(pSchema.Property, newtype)) + staticProps[i].Value = newtype } return types.NewObject(staticProps, nil), nil } @@ -1147,7 +1166,7 @@ func parseSchema(schema interface{}) (types.Type, error) { if len(subSchema.ItemsChildren) > 0 { if subSchema.ItemsChildrenIsSingleSchema { iSchema := subSchema.ItemsChildren[0] - newtype, err := parseSchema(iSchema) + newtype, err := parser.parseSchema(iSchema) if err != nil { return nil, fmt.Errorf("unexpected schema type %v", iSchema) } @@ -1156,7 +1175,7 @@ func parseSchema(schema interface{}) (types.Type, error) { newTypes := make([]types.Type, 0, len(subSchema.ItemsChildren)) for i := 0; i != len(subSchema.ItemsChildren); i++ { iSchema := subSchema.ItemsChildren[i] - newtype, err := parseSchema(iSchema) + newtype, err := parser.parseSchema(iSchema) if err != nil { return nil, fmt.Errorf("unexpected schema type %v", iSchema) } @@ -1171,11 +1190,11 @@ func parseSchema(schema interface{}) (types.Type, error) { // Assume types if not specified in schema if len(subSchema.PropertiesChildren) > 0 { if err := subSchema.Types.Add("object"); err == nil { - return parseSchema(subSchema) + return parser.parseSchema(subSchema) } } else if len(subSchema.ItemsChildren) > 0 { if err := subSchema.Types.Add("array"); err == nil { - return parseSchema(subSchema) + return parser.parseSchema(subSchema) } } @@ -1565,11 +1584,11 @@ func checkVoidCalls(env *TypeEnv, x interface{}) Errors { // // For example, given the following print statement: // -// print("the value of x is:", input.x) +// print("the value of x is:", input.x) // // The expression would be rewritten to: // -// print({__local0__ | __local0__ = "the value of x is:"}, {__local1__ | __local1__ = input.x}) +// print({__local0__ | __local0__ = "the value of x is:"}, {__local1__ | __local1__ = input.x}) func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals VarSet, body Body) Errors { var errs Errors diff --git a/ast/compile_test.go b/ast/compile_test.go index 8303b88fb2..bc23bf30a5 100644 --- a/ast/compile_test.go +++ b/ast/compile_test.go @@ -6,6 +6,7 @@ package ast import ( "bytes" + "encoding/json" "errors" "fmt" "reflect" @@ -1449,18 +1450,18 @@ func TestIllegalFunctionCallRewrite(t *testing.T) { expectedErrors []string }{ /*{ - note: "function call override in function value", - module: `package test - foo(x) := x - - p := foo(bar) { - #foo := 1 - bar := 2 - }`, - expectedErrors: []string{ - "undefined function foo", - }, - },*/ + note: "function call override in function value", + module: `package test + foo(x) := x + + p := foo(bar) { + #foo := 1 + bar := 2 + }`, + expectedErrors: []string{ + "undefined function foo", + }, + },*/ { note: "function call override in array comprehension value", module: `package test @@ -7332,3 +7333,121 @@ func TestKeepModules(t *testing.T) { } }) } + +// see https://github.com/open-policy-agent/opa/issues/5166 +func TestCompilerWithRecursiveSchema(t *testing.T) { + + jsonSchema := `{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://github.com/open-policy-agent/opa/issues/5166", + "type": "object", + "properties": { + "Something": { + "$ref": "#/$defs/X" + } + }, + "$defs": { + "X": { + "type": "object", + "properties": { + "Name": { "type": "string" }, + "Y": { + "$ref": "#/$defs/Y" + } + } + }, + "Y": { + "type": "object", + "properties": { + "X": { + "$ref": "#/$defs/X" + } + } + } + } +}` + + exampleModule := `# METADATA +# schemas: +# - input: schema.input +package opa.recursion + +deny { + input.Something.Y.X.Name == "Something" +} +` + + c := NewCompiler() + var schema interface{} + if err := json.Unmarshal([]byte(jsonSchema), &schema); err != nil { + t.Fatal(err) + } + schemaSet := NewSchemaSet() + schemaSet.Put(MustParseRef("schema.input"), schema) + c.WithSchemas(schemaSet) + + m := MustParseModuleWithOpts(exampleModule, ParserOptions{ProcessAnnotation: true}) + c.Compile(map[string]*Module{"testMod": m}) + if c.Failed() { + t.Errorf("Expected compilation to succeed, but got errors: %v", c.Errors) + } +} + +// see https://github.com/open-policy-agent/opa/issues/5166 +func TestCompilerWithRecursiveSchemaAndInvalidSource(t *testing.T) { + + jsonSchema := `{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://github.com/open-policy-agent/opa/issues/5166", + "type": "object", + "properties": { + "Something": { + "$ref": "#/$defs/X" + } + }, + "$defs": { + "X": { + "type": "object", + "properties": { + "Name": { "type": "string" }, + "Y": { + "$ref": "#/$defs/Y" + } + } + }, + "Y": { + "type": "object", + "properties": { + "X": { + "$ref": "#/$defs/X" + } + } + } + } +}` + + exampleModule := `# METADATA +# schemas: +# - input: schema.input +package opa.recursion + +deny { + input.Something.Y.X.ThisDoesNotExist == "Something" +} +` + + c := NewCompiler() + var schema interface{} + if err := json.Unmarshal([]byte(jsonSchema), &schema); err != nil { + t.Fatal(err) + } + schemaSet := NewSchemaSet() + schemaSet.Put(MustParseRef("schema.input"), schema) + c.WithSchemas(schemaSet) + + m := MustParseModuleWithOpts(exampleModule, ParserOptions{ProcessAnnotation: true}) + c.Compile(map[string]*Module{"testMod": m}) + if !c.Failed() { + t.Errorf("Expected compilation to fail, but it succeeded") + } +} diff --git a/ast/schema.go b/ast/schema.go index 76bd475677..8c96ac624e 100644 --- a/ast/schema.go +++ b/ast/schema.go @@ -54,7 +54,7 @@ func loadSchema(raw interface{}, allowNet []string) (types.Type, error) { return nil, err } - tpe, err := parseSchema(jsonSchema.RootSchema) + tpe, err := newSchemaParser().parseSchema(jsonSchema.RootSchema) if err != nil { return nil, fmt.Errorf("type checking: %w", err) } diff --git a/ast/schema_test.go b/ast/schema_test.go index 9f1d0943cd..ccfbc69d96 100644 --- a/ast/schema_test.go +++ b/ast/schema_test.go @@ -460,7 +460,7 @@ func TestParseSchemaWithSchemaBadSchema(t *testing.T) { if err != nil { t.Fatalf("Unable to compile schema: %v", err) } - newtype, err := parseSchema(jsonSchema) // Did not pass the subschema + newtype, err := newSchemaParser().parseSchema(jsonSchema) // Did not pass the subschema if err == nil { t.Fatalf("Expected parseSchema() = error, got nil") } @@ -814,6 +814,16 @@ func TestAnyOfArrayMissing(t *testing.T) { } } +func TestRecursiveSchema(t *testing.T) { + c := NewCompiler() + schemaSet := NewSchemaSet() + schemaSet.Put(SchemaRootRef, recursiveElements) + c.WithSchemas(schemaSet) + if c.schemaSet == nil { + t.Fatalf("Did not correctly compile an object schema with recursive elements") + } +} + const objectSchema = `{ "$schema": "http://json-schema.org/draft-07/schema", "$id": "http://example.com/example.json", @@ -1620,3 +1630,31 @@ const allOfSchemaWithUnevenArray = `{ } ] }` + +const recursiveElements = `{ + "type": "object", + "properties": { + "Something": { + "$ref": "#/$defs/X" + } + }, + "$defs": { + "X": { + "type": "object", + "properties": { + "Y": { + "$ref": "#/$defs/Y" + } + } + }, + "Y": { + "type": "object", + "properties": { + "X": { + "$ref": "#/$defs/X" + } + } + } + } +} +`