Skip to content

Commit

Permalink
fix: Add support for recursive json schema elements to prevent fatal …
Browse files Browse the repository at this point in the history
…error

Resolves open-policy-agent#5166

Signed-off-by: Liam Galvin <liam.galvin@aquasec.com>
  • Loading branch information
liamg committed Sep 22, 2022
1 parent f266848 commit bac22aa
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 65 deletions.
121 changes: 70 additions & 51 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -475,16 +475,16 @@ func (c *Compiler) GetArity(ref Ref) int {
//
// E.g., given the following module:
//
// package a.b.c
// package a.b.c
//
// p[k] = v { ... } # rule1
// p[k1] = v1 { ... } # rule2
// p[k] = v { ... } # rule1
// p[k1] = v1 { ... } # rule2
//
// The following calls yield the rules on the right.
//
// GetRulesExact("data.a.b.c.p") => [rule1, rule2]
// GetRulesExact("data.a.b.c.p.x") => nil
// GetRulesExact("data.a.b.c") => nil
// GetRulesExact("data.a.b.c.p") => [rule1, rule2]
// GetRulesExact("data.a.b.c.p.x") => nil
// GetRulesExact("data.a.b.c") => nil
func (c *Compiler) GetRulesExact(ref Ref) (rules []*Rule) {
node := c.RuleTree

Expand All @@ -502,16 +502,16 @@ func (c *Compiler) GetRulesExact(ref Ref) (rules []*Rule) {
//
// E.g., given the following module:
//
// package a.b.c
// package a.b.c
//
// p[k] = v { ... } # rule1
// p[k1] = v1 { ... } # rule2
// p[k] = v { ... } # rule1
// p[k1] = v1 { ... } # rule2
//
// The following calls yield the rules on the right.
//
// GetRulesForVirtualDocument("data.a.b.c.p") => [rule1, rule2]
// GetRulesForVirtualDocument("data.a.b.c.p.x") => [rule1, rule2]
// GetRulesForVirtualDocument("data.a.b.c") => nil
// GetRulesForVirtualDocument("data.a.b.c.p") => [rule1, rule2]
// GetRulesForVirtualDocument("data.a.b.c.p.x") => [rule1, rule2]
// GetRulesForVirtualDocument("data.a.b.c") => nil
func (c *Compiler) GetRulesForVirtualDocument(ref Ref) (rules []*Rule) {

node := c.RuleTree
Expand All @@ -532,17 +532,17 @@ func (c *Compiler) GetRulesForVirtualDocument(ref Ref) (rules []*Rule) {
//
// E.g., given the following module:
//
// package a.b.c
// package a.b.c
//
// p[x] = y { ... } # rule1
// p[k] = v { ... } # rule2
// q { ... } # rule3
// p[x] = y { ... } # rule1
// p[k] = v { ... } # rule2
// q { ... } # rule3
//
// The following calls yield the rules on the right.
//
// GetRulesWithPrefix("data.a.b.c.p") => [rule1, rule2]
// GetRulesWithPrefix("data.a.b.c.p.a") => nil
// GetRulesWithPrefix("data.a.b.c") => [rule1, rule2, rule3]
// GetRulesWithPrefix("data.a.b.c.p") => [rule1, rule2]
// GetRulesWithPrefix("data.a.b.c.p.a") => nil
// GetRulesWithPrefix("data.a.b.c") => [rule1, rule2, rule3]
func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) {

node := c.RuleTree
Expand Down Expand Up @@ -581,18 +581,18 @@ func extractRules(s []util.T) (rules []*Rule) {
//
// E.g., given the following module:
//
// package a.b.c
// package a.b.c
//
// p[x] = y { q[x] = y; ... } # rule1
// q[x] = y { ... } # rule2
// p[x] = y { q[x] = y; ... } # rule1
// q[x] = y { ... } # rule2
//
// The following calls yield the rules on the right.
//
// GetRules("data.a.b.c.p") => [rule1]
// GetRules("data.a.b.c.p.x") => [rule1]
// GetRules("data.a.b.c.q") => [rule2]
// GetRules("data.a.b.c") => [rule1, rule2]
// GetRules("data.a.b.d") => nil
// GetRules("data.a.b.c.p") => [rule1]
// GetRules("data.a.b.c.p.x") => [rule1]
// GetRules("data.a.b.c.q") => [rule2]
// GetRules("data.a.b.c") => [rule1, rule2]
// GetRules("data.a.b.d") => nil
func (c *Compiler) GetRules(ref Ref) (rules []*Rule) {

set := map[*Rule]struct{}{}
Expand Down Expand Up @@ -627,34 +627,34 @@ func (c *Compiler) GetRulesDynamic(ref Ref) []*Rule {
//
// E.g., given the following modules:
//
// package a.b.c
// package a.b.c
//
// r1 = 1 # rule1
// r1 = 1 # rule1
//
// and:
//
// package a.d.c
// package a.d.c
//
// r2 = 2 # rule2
// r2 = 2 # rule2
//
// The following calls yield the rules on the right.
//
// GetRulesDynamicWithOpts("data.a[x].c[y]", opts) => [rule1, rule2]
// GetRulesDynamicWithOpts("data.a[x].c.r2", opts) => [rule2]
// GetRulesDynamicWithOpts("data.a.b[x][y]", opts) => [rule1]
// GetRulesDynamicWithOpts("data.a[x].c[y]", opts) => [rule1, rule2]
// GetRulesDynamicWithOpts("data.a[x].c.r2", opts) => [rule2]
// GetRulesDynamicWithOpts("data.a.b[x][y]", opts) => [rule1]
//
// Using the RulesOptions parameter, the inclusion of hidden modules can be
// controlled:
//
// With
//
// package system.main
// package system.main
//
// r3 = 3 # rule3
// r3 = 3 # rule3
//
// We'd get this result:
//
// GetRulesDynamicWithOpts("data[x]", RulesOptions{IncludeHiddenModules: true}) => [rule1, rule2, rule3]
// GetRulesDynamicWithOpts("data[x]", RulesOptions{IncludeHiddenModules: true}) => [rule1, rule2, rule3]
//
// Without the options, it would be excluded.
func (c *Compiler) GetRulesDynamicWithOpts(ref Ref, opts RulesOptions) []*Rule {
Expand Down Expand Up @@ -1053,15 +1053,28 @@ func mergeSchemas(schemas ...*gojsonschema.SubSchema) (*gojsonschema.SubSchema,
return result, nil
}

func parseSchema(schema interface{}) (types.Type, error) {
type schemaParser struct {
definitionCache map[string][]*types.StaticProperty
}

func newSchemaParser() *schemaParser {
return &schemaParser{
definitionCache: map[string][]*types.StaticProperty{},
}
}

func (parser *schemaParser) parseSchema(schema interface{}) (types.Type, error) {
subSchema, ok := schema.(*gojsonschema.SubSchema)
if !ok {
return nil, fmt.Errorf("unexpected schema type %v", subSchema)
}

// Handle referenced schemas, returns directly when a $ref is found
if subSchema.RefSchema != nil {
return parseSchema(subSchema.RefSchema)
if existing, ok := parser.definitionCache[subSchema.RefSchema.ID.String()]; ok {
return types.NewObject(existing, nil), nil
}
return parser.parseSchema(subSchema.RefSchema)
}

// Handle anyOf
Expand All @@ -1073,7 +1086,7 @@ func parseSchema(schema interface{}) (types.Type, error) {
copySchema := *subSchema
copySchemaRef := &copySchema
copySchemaRef.AnyOf = nil
coreType, err := parseSchema(copySchemaRef)
coreType, err := parser.parseSchema(copySchemaRef)
if err != nil {
return nil, fmt.Errorf("unexpected schema type %v: %w", subSchema, err)
}
Expand All @@ -1088,7 +1101,7 @@ func parseSchema(schema interface{}) (types.Type, error) {

// Iterate through every property of AnyOf and add it to orType
for _, pSchema := range subSchema.AnyOf {
newtype, err := parseSchema(pSchema)
newtype, err := parser.parseSchema(pSchema)
if err != nil {
return nil, fmt.Errorf("unexpected schema type %v: %w", pSchema, err)
}
Expand All @@ -1111,12 +1124,12 @@ func parseSchema(schema interface{}) (types.Type, error) {
if err != nil {
return nil, err
}
return parseSchema(objectOrArrayResult)
return parser.parseSchema(objectOrArrayResult)
} else if subSchema.Types.String() != allOfResult.Types.String() {
return nil, fmt.Errorf("unable to merge these schemas")
}
}
return parseSchema(allOfResult)
return parser.parseSchema(allOfResult)
}

if subSchema.Types.IsTyped() {
Expand All @@ -1133,11 +1146,17 @@ func parseSchema(schema interface{}) (types.Type, error) {
if len(subSchema.PropertiesChildren) > 0 {
staticProps := make([]*types.StaticProperty, 0, len(subSchema.PropertiesChildren))
for _, pSchema := range subSchema.PropertiesChildren {
newtype, err := parseSchema(pSchema)
staticProps = append(staticProps, types.NewStaticProperty(pSchema.Property, nil))
}
if subSchema.Parent != nil {
parser.definitionCache[subSchema.ID.String()] = staticProps
}
for i, pSchema := range subSchema.PropertiesChildren {
newtype, err := parser.parseSchema(pSchema)
if err != nil {
return nil, fmt.Errorf("unexpected schema type %v: %w", pSchema, err)
}
staticProps = append(staticProps, types.NewStaticProperty(pSchema.Property, newtype))
staticProps[i].Value = newtype
}
return types.NewObject(staticProps, nil), nil
}
Expand All @@ -1147,7 +1166,7 @@ func parseSchema(schema interface{}) (types.Type, error) {
if len(subSchema.ItemsChildren) > 0 {
if subSchema.ItemsChildrenIsSingleSchema {
iSchema := subSchema.ItemsChildren[0]
newtype, err := parseSchema(iSchema)
newtype, err := parser.parseSchema(iSchema)
if err != nil {
return nil, fmt.Errorf("unexpected schema type %v", iSchema)
}
Expand All @@ -1156,7 +1175,7 @@ func parseSchema(schema interface{}) (types.Type, error) {
newTypes := make([]types.Type, 0, len(subSchema.ItemsChildren))
for i := 0; i != len(subSchema.ItemsChildren); i++ {
iSchema := subSchema.ItemsChildren[i]
newtype, err := parseSchema(iSchema)
newtype, err := parser.parseSchema(iSchema)
if err != nil {
return nil, fmt.Errorf("unexpected schema type %v", iSchema)
}
Expand All @@ -1171,11 +1190,11 @@ func parseSchema(schema interface{}) (types.Type, error) {
// Assume types if not specified in schema
if len(subSchema.PropertiesChildren) > 0 {
if err := subSchema.Types.Add("object"); err == nil {
return parseSchema(subSchema)
return parser.parseSchema(subSchema)
}
} else if len(subSchema.ItemsChildren) > 0 {
if err := subSchema.Types.Add("array"); err == nil {
return parseSchema(subSchema)
return parser.parseSchema(subSchema)
}
}

Expand Down Expand Up @@ -1565,11 +1584,11 @@ func checkVoidCalls(env *TypeEnv, x interface{}) Errors {
//
// For example, given the following print statement:
//
// print("the value of x is:", input.x)
// print("the value of x is:", input.x)
//
// The expression would be rewritten to:
//
// print({__local0__ | __local0__ = "the value of x is:"}, {__local1__ | __local1__ = input.x})
// print({__local0__ | __local0__ = "the value of x is:"}, {__local1__ | __local1__ = input.x})
func rewritePrintCalls(gen *localVarGenerator, getArity func(Ref) int, globals VarSet, body Body) Errors {

var errs Errors
Expand Down
Loading

0 comments on commit bac22aa

Please sign in to comment.