Adding future.strict keyword
Fixes: open-policy-agent#6247
Signed-off-by: Johan Fylling <johan.dev@fylling.se>
johanfylling authored and ashutosh-narkar committed Oct 17, 2023
1 parent 9a1a427 commit e64e3a1
Showing 5 changed files with 244 additions and 5 deletions.
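In short, this commit adds support for an `import future.strict` statement: it implies all of `future.keywords` (in, every, contains, if), is mutually exclusive with an explicit `future.keywords` import, and makes the parser require `if` before rule bodies and `contains` on partial set rules. The sketch below is not part of the diff; it illustrates the intended behavior through the public ast.ParseModule helper, with the module text and error message taken from the new parser tests, and it assumes a build of OPA at this commit.

package main

import (
	"fmt"

	"github.com/open-policy-agent/opa/ast"
)

func main() {
	// Under `import future.strict`, a rule body without the `if` keyword is a parse error.
	module := `package test

import future.strict

p { input.x == 1 }
`
	if _, err := ast.ParseModule("example.rego", module); err != nil {
		fmt.Println(err) // per the new tests: rego_parse_error: `if` keyword is required before rule body
	}
}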
18 changes: 18 additions & 0 deletions ast/internal/scanner/scanner.go
@@ -26,6 +26,7 @@ type Scanner struct {
width int
errors []Error
keywords map[string]tokens.Token
strict bool
}

// Error represents a scanner error.
@@ -102,6 +103,23 @@ func (s *Scanner) AddKeyword(kw string, tok tokens.Token) {
}
}

func (s *Scanner) HasKeyword(keywords map[string]tokens.Token) bool {
for kw := range s.keywords {
if _, ok := keywords[kw]; ok {
return true
}
}
return false
}

func (s *Scanner) SetStrict() {
s.strict = true
}

func (s *Scanner) Strict() bool {
return s.strict
}

// WithKeywords returns a new copy of the Scanner struct `s`, with the set
// of known keywords being that of `s` with `kws` added.
func (s *Scanner) WithKeywords(kws map[string]tokens.Token) *Scanner {
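The scanner gains a `strict` flag with `SetStrict`/`Strict` accessors and `HasKeyword`, which reports whether any already-registered keyword also appears in a supplied set; the parser uses it below to detect that `future.keywords` was imported before `future.strict`. A rough, self-contained illustration of that overlap check (plain maps, not the internal scanner type):

package main

import "fmt"

// hasAny mirrors the intent of Scanner.HasKeyword: report whether any
// registered keyword is also present in the candidate set.
func hasAny(registered, candidates map[string]struct{}) bool {
	for kw := range registered {
		if _, ok := candidates[kw]; ok {
			return true
		}
	}
	return false
}

func main() {
	registered := map[string]struct{}{"in": {}}
	future := map[string]struct{}{"in": {}, "every": {}, "contains": {}, "if": {}}
	fmt.Println(hasAny(registered, future)) // true: a future keyword was already imported
}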
60 changes: 57 additions & 3 deletions ast/parser.go
@@ -362,6 +362,21 @@ func (p *Parser) Parse() ([]Statement, []*Comment, Errors) {
break
}

// Strict check
if p.s.s.Strict() {
for _, stmt := range stmts {
// TODO: Deal with "naked" ast.Body
if rule, ok := stmt.(*Rule); ok {
if rule.Body != nil && !ruleComposedWithKeyword(rule, tokens.If) {
p.error(rule.Location, "`if` keyword is required before rule body")
}
if rule.Head.RuleKind() == MultiValue && !ruleComposedWithKeyword(rule, tokens.Contains) {
p.error(rule.Location, "`contains` keyword is required for partial set rules")
}
}
}
}

if p.po.ProcessAnnotation {
stmts = p.parseAnnotations(stmts)
}
@@ -382,6 +397,15 @@ func (p *Parser) Parse() ([]Statement, []*Comment, Errors) {
return stmts, p.s.comments, p.s.errors
}

func ruleComposedWithKeyword(rule *Rule, keyword tokens.Token) bool {
for _, kw := range rule.Head.keywords {
if kw == keyword {
return true
}
}
return false
}

func (p *Parser) parseAnnotations(stmts []Statement) []Statement {

annotStmts, errs := parseAnnotations(p.s.comments)
@@ -581,6 +605,10 @@ func (p *Parser) parseRules() []*Rule {
return nil
}

if usesContains {
rule.Head.keywords = append(rule.Head.keywords, tokens.Contains)
}

if rule.Default {
if !p.validateDefaultRuleValue(&rule) {
return nil
@@ -599,6 +627,10 @@ func (p *Parser) parseRules() []*Rule {
// back-compat with `p[x] { ... }``
hasIf := p.s.tok == tokens.If

if hasIf {
rule.Head.keywords = append(rule.Head.keywords, tokens.If)
}

// p[x] if ... becomes a single-value rule p[x]
if hasIf && !usesContains && len(rule.Head.Ref()) == 2 {
if rule.Head.Value == nil {
@@ -660,7 +692,7 @@ func (p *Parser) parseRules() []*Rule {
p.scan()

case usesContains:
rule.Body = NewBody(NewExpr(BooleanTerm(true).SetLocation(rule.Location)).SetLocation(rule.Location))
//rule.Body = NewBody(NewExpr(BooleanTerm(true).SetLocation(rule.Location)).SetLocation(rule.Location))
return []*Rule{&rule}

default:
@@ -2497,8 +2529,8 @@ var futureKeywords = map[string]tokens.Token{
func (p *Parser) futureImport(imp *Import, allowedFutureKeywords map[string]tokens.Token) {
path := imp.Path.Value.(Ref)

if len(path) == 1 || !path[1].Equal(StringTerm("keywords")) {
p.errorf(imp.Path.Location, "invalid import, must be `future.keywords`")
if len(path) == 1 || (!path[1].Equal(StringTerm("keywords")) && !path[1].Equal(StringTerm("strict"))) {
p.errorf(imp.Path.Location, "invalid import, must be `future.keywords` or `future.strict`")
return
}

@@ -2512,6 +2544,28 @@ func (p *Parser) futureImport(imp *Import, allowedFutureKeywords map[string]toke
kwds = append(kwds, k)
}

if path[1].Equal(StringTerm("strict")) {
if len(path) > 2 {
p.errorf(imp.Path.Location, "invalid import, must be `future.strict`")
return
}
if p.s.s.HasKeyword(futureKeywords) && !p.s.s.Strict() {
// We have imported future keywords, but they didn't come from another `future.strict` import.
p.errorf(imp.Path.Location, "the `future.strict` import implies `future.keywords`, these are therefore mutually exclusive")
return
}
p.s.s.SetStrict()
for _, kw := range kwds {
p.s.s.AddKeyword(kw, allowedFutureKeywords[kw])
}
return
}

if p.s.s.Strict() {
p.errorf(imp.Path.Location, "the `future.strict` import implies `future.keywords`, these are therefore mutually exclusive")
return
}

switch len(path) {
case 2: // all keywords imported, nothing to do
case 3: // one keyword imported
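The import handling above enforces the two new rules in either order: `future.strict` registers all future keywords itself, and it refuses to coexist with an explicit `future.keywords` import. A sketch of both cases (not part of the diff), again via ast.ParseModule, with module texts mirroring the new tests:

package main

import (
	"fmt"

	"github.com/open-policy-agent/opa/ast"
)

func main() {
	// Allowed: `future.strict` implies the future keywords, so `contains`
	// and `if` work without a separate `future.keywords` import.
	ok := `package test

import future.strict

p contains x if { x = input.x }
`
	if _, err := ast.ParseModule("ok.rego", ok); err != nil {
		fmt.Println("unexpected error:", err)
	}

	// Rejected: combining the two imports is a parse error.
	bad := `package test

import future.strict
import future.keywords
`
	_, err := ast.ParseModule("bad.rego", bad)
	fmt.Println(err) // per the new tests: the `future.strict` import implies `future.keywords`, ...
}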
6 changes: 6 additions & 0 deletions ast/parser_ext.go
@@ -635,6 +635,9 @@ func ParseStatementsWithOpts(filename, input string, popts ParserOptions) ([]Sta
return stmts, comments, nil
}

// TODO: Put this somewhere appropriate
var FutureStrictRef = MustParseRef("future.strict")

func parseModule(filename string, stmts []Statement, comments []*Comment) (*Module, error) {

if len(stmts) == 0 {
@@ -661,6 +664,9 @@ func parseModule(filename string, stmts []Statement, comments []*Comment) (*Modu
switch stmt := stmt.(type) {
case *Import:
mod.Imports = append(mod.Imports, stmt)
if Compare(stmt.Path.Value, FutureStrictRef) == 0 {
mod.strict = true
}
case *Rule:
setRuleModule(stmt, mod)
mod.Rules = append(mod.Rules, stmt)
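The parseModule change records the opt-in on the unexported Module.strict field and exports FutureStrictRef. A sketch (not part of the diff, assuming this branch) of how a caller might detect the import externally, using the existing public helpers ast.ParseModule and ast.Compare:

package main

import (
	"fmt"

	"github.com/open-policy-agent/opa/ast"
)

func main() {
	mod, err := ast.ParseModule("x.rego", `package test

import future.strict

p contains 1 if 1 == 1
`)
	if err != nil {
		panic(err)
	}
	// The module's own strict flag is unexported; callers can still detect the
	// opt-in by comparing import paths against the new FutureStrictRef value.
	for _, imp := range mod.Imports {
		if ast.Compare(imp.Path.Value, ast.FutureStrictRef) == 0 {
			fmt.Println("module opted in to future.strict")
		}
	}
}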
161 changes: 159 additions & 2 deletions ast/parser_test.go
@@ -1236,13 +1236,17 @@ func TestFutureImports(t *testing.T) {
assertParseErrorContains(t, "unknown keyword", "import future.keywords.xyz", "unexpected keyword, must be one of [contains every if in]")
assertParseErrorContains(t, "all keyword import + alias", "import future.keywords as xyz", "`future` imports cannot be aliased")
assertParseErrorContains(t, "keyword import + alias", "import future.keywords.in as xyz", "`future` imports cannot be aliased")
assertParseErrorContains(t, "future.strict.abc", "import future.strict.abc", "invalid import, must be `future.strict`")

assertParseImport(t, "import kw with kw in options",
"import future.keywords.in", &Import{Path: RefTerm(VarTerm("future"), StringTerm("keywords"), StringTerm("in"))},
ParserOptions{FutureKeywords: []string{"in"}})
assertParseImport(t, "import kw with all kw in options",
"import future.keywords.in", &Import{Path: RefTerm(VarTerm("future"), StringTerm("keywords"), StringTerm("in"))},
ParserOptions{AllFutureKeywords: true})
assertParseImport(t, "import strict",
"import future.strict", &Import{Path: RefTerm(VarTerm("future"), StringTerm("strict"))},
ParserOptions{})

mod := `
package p
Expand All @@ -1258,6 +1262,13 @@ func TestFutureImports(t *testing.T) {
}
assertParseModule(t, "multiple imports, all kw in options", mod, &parsed, ParserOptions{AllFutureKeywords: true})
assertParseModule(t, "multiple imports, single in options", mod, &parsed, ParserOptions{FutureKeywords: []string{"in"}})

mod = `
package p
import future.strict
import future.keywords.in
`
assertParseModuleErrorMatch(t, "strict and keywords imported", mod, "rego_parse_error: the `future.strict` import implies `future.keywords`, these are therefore mutually exclusive")
}

func TestFutureImportsExtraction(t *testing.T) {
@@ -1275,14 +1286,34 @@ func TestFutureImportsExtraction(t *testing.T) {
{
note: "all keywords imported",
imp: "import future.keywords",
exp: map[string]tokens.Token{"in": tokens.In},
exp: map[string]tokens.Token{
"in": tokens.In,
"every": tokens.Every,
"contains": tokens.Contains,
"if": tokens.If,
},
},
{
note: "all keywords + single keyword imported",
imp: `
import future.keywords
import future.keywords.in`,
exp: map[string]tokens.Token{"in": tokens.In},
exp: map[string]tokens.Token{
"in": tokens.In,
"every": tokens.Every,
"contains": tokens.Contains,
"if": tokens.If,
},
},
{
note: "future.strict imported",
imp: "import future.strict",
exp: map[string]tokens.Token{
"in": tokens.In,
"every": tokens.Every,
"contains": tokens.Contains,
"if": tokens.If,
},
},
}
for _, tc := range tests {
@@ -1302,6 +1333,121 @@ func TestFutureImportsExtraction(t *testing.T) {
}
}

func TestFutureStrictImport(t *testing.T) {
tests := []struct {
note string
module string
expectedErrors []string
}{
{
note: "only future.strict imported",
module: `package test
import future.strict
p contains 1 if 1 == 1`,
},
{
note: "future.strict and future.keywords imported",
module: `package test
import future.strict
import future.keywords
p contains 1 if {
input.x == 1
}`,
expectedErrors: []string{"rego_parse_error: the `future.strict` import implies `future.keywords`, these are therefore mutually exclusive"},
},
{
note: "`if` keyword used on rule",
module: `package test
import future.strict
p if {
input.x == 1
}`,
},
{
note: "`if` keyword not used on rule",
module: `package test
import future.strict
p {
input.x == 1
}`,
expectedErrors: []string{"rego_parse_error: `if` keyword is required before rule body"},
},
{
note: "`if` keyword not used on constant definition",
module: `package test
import future.strict
p := 1`,
},
// TODO: "`if` keyword used before else body"
// TODO: "`if` keyword not used before else body"
{
note: "`contains` keyword used on partial set rule (const key)",
module: `package test
import future.strict
p contains "q"`,
},
{ // FIXME: need to deal with "naked" statements in the parser
note: "`contains` keyword not used on partial set rule (const key)",
module: `package test
import future.strict
p.q`,
expectedErrors: []string{"rego_parse_error: `contains` keyword is required for partial set rules"},
},
{
note: "`contains` keyword used on partial set rule (var key, no body)",
module: `package test
import future.strict
p contains input.x`,
},
{
note: "`contains` keyword not used on partial set rule (var key, no body)",
module: `package test
import future.strict
p[input.x]`,
expectedErrors: []string{"rego_parse_error: `contains` keyword is required for partial set rules"},
},
{
note: "`contains` keyword used on partial set rule (var key)",
module: `package test
import future.strict
p contains x if { x = input.x}`,
},
{
note: "`if` keyword used on partial map rule (would be multi-value without `if`)",
module: `package test
import future.strict
p[x] if { x = input.x}`,
},
{
note: "`contains` and `if` keyword not used on partial rule",
module: `package test
import future.strict
p[x] { x = input.x}`,
// The developer likely intended a partial set.
expectedErrors: []string{
"rego_parse_error: `contains` keyword is required for partial set rules",
"rego_parse_error: `if` keyword is required before rule body",
},
},
}

for _, tc := range tests {
t.Run(tc.note, func(t *testing.T) {
parser := NewParser().WithFilename("").WithReader(bytes.NewBufferString(tc.module))
_, _, errs := parser.Parse()
if len(tc.expectedErrors) == 0 && len(errs) > 0 {
t.Fatalf("expected no errors, got:\n\n%v", errs)
}
actual := errs.Error()
for _, expected := range tc.expectedErrors {
if !strings.Contains(actual, expected) {
t.Errorf("expected error:\n\n%q\n\ngot:\n\n%v", expected, actual)
}
}
})
}
}

func TestIsValidImportPath(t *testing.T) {
tests := []struct {
path string
@@ -5174,6 +5320,17 @@ func assertParseModuleError(t *testing.T, msg, input string) {
}
}

func assertParseModuleErrorMatch(t *testing.T, msg, input string, expected string) {
t.Helper()
m, err := ParseModule("", input)
if err == nil {
t.Errorf("Error on test \"%s\": expected parse error: %v (parsed)", msg, m)
}
if !strings.Contains(err.Error(), expected) {
t.Errorf("Error on test \"%s\"; expected:\n\n%v\n\ngot:\n\n%v", msg, expected, err)
}
}

func assertParsePackage(t *testing.T, msg string, input string, correct *Package) {
assertParseOne(t, msg, input, func(parsed interface{}) {
pkg := parsed.(*Package)
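The new TestFutureStrictImport table drives the parser directly through NewParser/WithFilename/WithReader, the same public entry point the existing parser tests use. Outside the test suite, the same approach collects all strict-mode errors for a module at once; this sketch (not part of the diff) mirrors the last table entry:

package main

import (
	"bytes"
	"fmt"

	"github.com/open-policy-agent/opa/ast"
)

func main() {
	// A partial set written in the old style trips both strict checks.
	module := `package test
import future.strict
p[x] { x = input.x }
`
	parser := ast.NewParser().WithFilename("strict.rego").WithReader(bytes.NewBufferString(module))
	_, _, errs := parser.Parse()
	for _, err := range errs {
		fmt.Println(err) // expects the missing-`contains` and missing-`if` errors seen in the tests
	}
}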
4 changes: 4 additions & 0 deletions ast/policy.go
@@ -8,6 +8,7 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/open-policy-agent/opa/ast/internal/tokens"
"math/rand"
"strings"
"time"
@@ -146,6 +147,7 @@
Rules []*Rule `json:"rules,omitempty"`
Comments []*Comment `json:"comments,omitempty"`
stmts []Statement
strict bool
}

// Comment contains the raw text from the comment in the definition.
@@ -203,6 +205,8 @@
Value *Term `json:"value,omitempty"`
Assign bool `json:"assign,omitempty"`
Location *Location `json:"location,omitempty"`
// FIXME: add Keyword type?
keywords []tokens.Token // TODO: add to JSON serialization?

jsonOptions astJSON.Options
}
