From 68f5198a22defc6609a5e3dc35c1ea936bf725e0 Mon Sep 17 00:00:00 2001 From: Aman Mangal Date: Thu, 11 Jul 2019 14:42:41 +0530 Subject: [PATCH] Add support for lexing, parsing Conditional Upsert --- gql/mutation.go | 3 + gql/parser_mutation.go | 61 +++++++++-- gql/state.go | 87 ++++++++++++++-- gql/upsert_test.go | 230 +++++++++++++++++++++++++++++++++++++++++ lex/lexer.go | 1 + 5 files changed, 364 insertions(+), 18 deletions(-) diff --git a/gql/mutation.go b/gql/mutation.go index 55e7f7839c2..aecaca6db94 100644 --- a/gql/mutation.go +++ b/gql/mutation.go @@ -34,6 +34,9 @@ var ( type Mutation struct { Set []*api.NQuad Del []*api.NQuad + + // CondTree stores the condition of mutation (@if directive) + CondTree *FilterTree } // ParseUid parses the given string into an UID. This method returns with an error diff --git a/gql/parser_mutation.go b/gql/parser_mutation.go index 1fac61ba14f..abd65a1aa8d 100644 --- a/gql/parser_mutation.go +++ b/gql/parser_mutation.go @@ -19,9 +19,43 @@ package gql import ( "github.com/dgraph-io/dgo/protos/api" "github.com/dgraph-io/dgraph/lex" + "github.com/pkg/errors" ) +// ParseIfDirective parses the if directive into a FilterTree +func ParseIfDirective(cond string) (*FilterTree, error) { + if cond == "" { + return nil, nil + } + + lexer := lex.NewLexer(cond) + lexer.Run(lexIfDirective) + if err := lexer.ValidateResult(); err != nil { + return nil, err + } + + // ===>@<=== if(...) + it := lexer.NewIterator() + if !it.Next() { + return nil, errors.Errorf("Empty if directive") + } + if item := it.Item(); item.Typ != itemAt { + return nil, errors.Errorf("Expected @, found [%v]", item.Val) + } + + // @ ===>if<=== (...) + if !it.Next() { + return nil, errors.Errorf("Invalid if directive") + } + if item := it.Item(); item.Typ != itemName || item.Val != "if" { + return nil, errors.Errorf("Expected if, found [%v]", item.Val) + } + + // @if ===>(...)<=== + return parseFilter(it) +} + // ParseMutation parses a block into a mutation. Returns an object with a mutation or // an upsert block with mutation, otherwise returns nil with an error. func ParseMutation(mutation string) (mu *api.Mutation, err error) { @@ -39,11 +73,11 @@ func ParseMutation(mutation string) (mu *api.Mutation, err error) { item := it.Item() switch item.Typ { case itemUpsertBlock: - if mu, err = ParseUpsertBlock(it); err != nil { + if mu, err = parseUpsertBlock(it); err != nil { return nil, err } case itemLeftCurl: - if mu, err = ParseMutationBlock(it); err != nil { + if mu, err = parseMutationBlock(it); err != nil { return nil, err } default: @@ -58,10 +92,11 @@ func ParseMutation(mutation string) (mu *api.Mutation, err error) { return mu, nil } -// ParseUpsertBlock parses the upsert block -func ParseUpsertBlock(it *lex.ItemIterator) (*api.Mutation, error) { +// parseUpsertBlock parses the upsert block +func parseUpsertBlock(it *lex.ItemIterator) (*api.Mutation, error) { var mu *api.Mutation var queryText string + var condText string var queryFound bool // ===>upsert<=== {...} @@ -86,6 +121,7 @@ func ParseUpsertBlock(it *lex.ItemIterator) (*api.Mutation, error) { return nil, errors.Errorf("Query op not found in upsert block") } else { mu.Query = queryText + mu.Cond = condText return mu, nil } @@ -109,8 +145,19 @@ func ParseUpsertBlock(it *lex.ItemIterator) (*api.Mutation, error) { if !it.Next() { return nil, errors.Errorf("Unexpected end of upsert block") } + + // upsert { mutation ===>@if(...)<=== {....} query{...}} + item = it.Item() + if item.Typ == itemUpsertBlockOpContent { + condText = item.Val + if !it.Next() { + return nil, errors.Errorf("Unexpected end of upsert block") + } + } + + // upsert @if(...) ===>{<=== ....} var err error - if mu, err = ParseMutationBlock(it); err != nil { + if mu, err = parseMutationBlock(it); err != nil { return nil, err } @@ -133,8 +180,8 @@ func ParseUpsertBlock(it *lex.ItemIterator) (*api.Mutation, error) { return nil, errors.Errorf("Invalid upsert block") } -// ParseMutationBlock parses the mutation block -func ParseMutationBlock(it *lex.ItemIterator) (*api.Mutation, error) { +// parseMutationBlock parses the mutation block +func parseMutationBlock(it *lex.ItemIterator) (*api.Mutation, error) { var mu api.Mutation item := it.Item() diff --git a/gql/state.go b/gql/state.go index e1d12645065..c0fb729dd36 100644 --- a/gql/state.go +++ b/gql/state.go @@ -65,6 +65,34 @@ const ( itemMathOp ) +// lexIfDirective lexes the @if directive in a mutation block +func lexIfDirective(l *lex.Lexer) lex.StateFn { + l.Mode = lexIfDirective + for { + switch r := l.Next(); { + case r == lex.EOF: + l.Emit(lex.ItemEOF) + return nil + case isSpace(r) || lex.IsEndOfLine(r): + l.Ignore() + case r == '#': + return lexComment + case r == leftRound: + l.Emit(itemLeftRound) + l.AcceptRun(isSpace) + l.Ignore() + l.ArgDepth++ + l.WhetherIf = true + return lexFuncOrArg + case r == at: + l.Emit(itemAt) + return lexDirectiveOrLangList + default: + return l.Errorf("Unrecognized character in lexText: %#U", r) + } + } +} + // lexIdentifyBlock identifies whether it is an upsert block // If the block begins with "{" => mutation block // Else if the block begins with "upsert" => upsert block @@ -93,11 +121,7 @@ func lexIdentifyBlock(l *lex.Lexer) lex.StateFn { func lexNameBlock(l *lex.Lexer) lex.StateFn { for { // The caller already checked isNameBegin, and absorbed one rune. - r := l.Next() - if isNameSuffix(r) { - continue - } - l.Backup() + l.AcceptRun(isNameSuffix) switch word := l.Input[l.Start:l.Pos]; word { case "upsert": l.Emit(itemUpsertBlock) @@ -140,11 +164,7 @@ func lexUpsertBlock(l *lex.Lexer) lex.StateFn { func lexNameUpsertOp(l *lex.Lexer) lex.StateFn { for { // The caller already checked isNameBegin, and absorbed one rune. - r := l.Next() - if isNameSuffix(r) { - continue - } - l.Backup() + l.AcceptRun(isNameSuffix) word := l.Input[l.Start:l.Pos] switch word { case "query": @@ -187,10 +207,48 @@ func lexBlockContent(l *lex.Lexer) lex.StateFn { } } +// lexIfContent lexes the whole of @if directive in a mutation block +func lexIfContent(l *lex.Lexer) lex.StateFn { + if r := l.Next(); r != at { + return l.Errorf("Expected [@], found; [%#U]", r) + } + + l.AcceptRun(isNameSuffix) + word := l.Input[l.Start:l.Pos] + if word != "@if" { + return l.Errorf("Expected @if, found [%v]", word) + } + + depth := 0 + for { + switch l.Next() { + case lex.EOF: + return l.Errorf("Invalid if directive") + case quote: + if err := l.LexQuotedString(); err != nil { + return l.Errorf(err.Error()) + } + case leftRound: + depth++ + case rightRound: + depth-- + if depth < 0 { + return l.Errorf("Unopened ) found in if directive") + } else if depth == 0 { + l.Emit(itemUpsertBlockOpContent) + return lexInsideMutation + } + } + } +} + func lexInsideMutation(l *lex.Lexer) lex.StateFn { l.Mode = lexInsideMutation for { switch r := l.Next(); { + case r == at: + l.Backup() + return lexIfContent case r == rightCurl: l.Depth-- l.Emit(itemRightCurl) @@ -304,6 +362,13 @@ func lexFuncOrArg(l *lex.Lexer) lex.StateFn { return l.Errorf("Empty Argument") } if l.ArgDepth == 0 { + // TODO(Aman): We should make l.Mode a stack instead + // and avoid such conditions as below. + if l.WhetherIf { + l.WhetherIf = false + return lexIfDirective + } + return lexQuery // Filter directive is done. } case r == lex.EOF: @@ -586,7 +651,7 @@ func lexOperationType(l *lex.Lexer) lex.StateFn { l.Emit(itemOpType) return lexInsideSchema } else { - l.Errorf("Invalid operation type: %s", word) + return l.Errorf("Invalid operation type: %s", word) } return lexQuery diff --git a/gql/upsert_test.go b/gql/upsert_test.go index de02e6c53e4..6952377e9ba 100644 --- a/gql/upsert_test.go +++ b/gql/upsert_test.go @@ -312,3 +312,233 @@ func TestUpsertWithFilter(t *testing.T) { _, err := ParseMutation(query) require.Nil(t, err) } + +func TestConditionalUpsertWithNewlines(t *testing.T) { + query := `upsert { + mutation @if(eq(count(m), 1) + AND + gt(count(f), 0)) { + set { + uid(m) "45" . + uid(f) "45" . + } + } + + query { + me(func: eq(age, 34)) @filter(ge(name, "user")) { + uid + friend { + uid + age + } + } + } +} +` + mu, err := ParseMutation(query) + require.Nil(t, err) + _, err = ParseIfDirective(mu.Cond) + require.Nil(t, err) +} + +func TestConditionalUpsertFuncTree(t *testing.T) { + query := `upsert { + mutation @if( ( eq(count(m), 1) + OR + lt(90, count(h))) + AND + gt(count(f), 0)) { + set { + uid(m) "45" . + uid(f) "45" . + } + } + + query { + me(func: eq(age, 34)) @filter(ge(name, "user")) { + uid + friend { + uid + age + } + } + } +} +` + mu, err := ParseMutation(query) + require.Nil(t, err) + _, err = ParseIfDirective(mu.Cond) + require.Nil(t, err) +} + +func TestConditionalUpsertMultipleFuncArg(t *testing.T) { + query := `upsert { + mutation @if( ( eq(count(m), count(t)) + OR + lt(90, count(h))) + AND + gt(count(f), 0)) { + set { + uid(m) "45" . + uid(f) "45" . + } + } + + query { + me(func: eq(age, 34)) @filter(ge(name, "user")) { + uid + friend { + uid + age + } + } + } +} +` + mu, err := ParseMutation(query) + require.Nil(t, err) + _, err = ParseIfDirective(mu.Cond) + require.Contains(t, err.Error(), "Multiple functions as arguments not allowed") +} + +func TestConditionalUpsertErrMissingRightRound(t *testing.T) { + query := `upsert { + mutation @if(eq(len(m, 1) + AND + gt(len(f), 0)) { + set { + uid(m) "45" . + uid(f) "45" . + } + } + + query { + me(func: eq(age, 34)) @filter(ge(name, "user")) { + uid + friend { + uid + age + } + } + } +} +` + _, err := ParseMutation(query) + require.Contains(t, err.Error(), "Invalid if directive") +} + +func TestConditionalUpsertErrUnclosed(t *testing.T) { + query := `upsert { + mutation @if(eq(len(m), 1) AND gt(len(f), 0))` + _, err := ParseMutation(query) + require.Contains(t, err.Error(), "Unclosed mutation action") +} + +func TestConditionalUpsertErrInvalidIf(t *testing.T) { + query := `upsert { + mutation @if` + _, err := ParseMutation(query) + require.Contains(t, err.Error(), "Invalid if directive") +} + +func TestConditionalUpsertErrWrongIf(t *testing.T) { + query := `upsert { + mutation @fi( ( eq(len(m), 1) + OR + lt(len(h), 90)) + AND + gt(len(f), 0)) { + set { + uid(m) "45" . + uid(f) "45" . + } + } + + query { + me(func: eq(age, 34)) @filter(ge(name, "user")) { + uid + friend { + uid + age + } + } + } +} +` + _, err := ParseMutation(query) + require.Contains(t, err.Error(), "Expected @if, found [@fi]") +} + +func TestIfDirectiveNoIf(t *testing.T) { + ft, err := ParseIfDirective("") + require.Nil(t, ft) + require.Nil(t, err) +} + +func TestIfDirectiveWithComment(t *testing.T) { + cond := ` @if( ( eq(count(m), 1) + # This is a comment + OR + lt(90, count(h))) + AND + gt(count(f), 0))` + _, err := ParseIfDirective(cond) + require.Nil(t, err) +} + +func TestIfDirectiveWithSameLineComment(t *testing.T) { + cond := ` @if( ( eq(count(m), 1) + OR # This is another comment + lt(90, count(h)) ) + AND + gt(count(f), 0))` + _, err := ParseIfDirective(cond) + require.Nil(t, err) +} + +func TestIfDirectiveWithAfterIfComment(t *testing.T) { + cond := ` @if # This comment is okay too + ((eq(count(m), 1) + OR # This is another comment + lt(90, count(h)) ) + AND + gt(count(f), 0)) ` + _, err := ParseIfDirective(cond) + require.Nil(t, err) +} + +func TestIfDirectiveErrMissingAt(t *testing.T) { + cond := ` if( ( eq(count(m), 1) + OR + lt(90, count(h))) + AND + gt(count(f), 0))` + _, err := ParseIfDirective(cond) + require.Contains(t, err.Error(), "Unrecognized character in lexText: U+0069 'i'") +} + +func TestIfDirectiveErrWhiteSpace(t *testing.T) { + cond := ` ` + _, err := ParseIfDirective(cond) + require.Contains(t, err.Error(), "Expected @, found []") +} + +func TestIfDirectiveNoParseErr(t *testing.T) { + cond := ` @if( ( eq(count(m), 1) + OR + lt(90, count(h))) + AND + gt(count(h), 0)) @if()` + _, err := ParseIfDirective(cond) + require.Nil(t, err) +} + +func TestIfDirectiveErrMissingRightRound(t *testing.T) { + cond := `@if( ( eq(count(m), 1) + OR + lt(90, count(h))) + AND + gt(count(h), 0)` + _, err := ParseIfDirective(cond) + require.Contains(t, err.Error(), "Unclosed Brackets") +} diff --git a/lex/lexer.go b/lex/lexer.go index fae1d1d913c..359cbe95e93 100644 --- a/lex/lexer.go +++ b/lex/lexer.go @@ -172,6 +172,7 @@ type Lexer struct { items []Item // channel of scanned items. Depth int // nesting of {} BlockDepth int // nesting of blocks (e.g. mutation block inside upsert block) + WhetherIf bool // Used to figure out where to return go from lexFuncOrArg ArgDepth int // nesting of () Mode StateFn // Default state to go back to after reading a token. Line int // the current line number corresponding to Start