diff --git a/dgraph/cmd/alpha/http.go b/dgraph/cmd/alpha/http.go index 080ff55d9be..c91b973fa87 100644 --- a/dgraph/cmd/alpha/http.go +++ b/dgraph/cmd/alpha/http.go @@ -326,6 +326,13 @@ func mutationHandler(w http.ResponseWriter, r *http.Request) { return } } + if condText, ok := ms["cond"]; ok && condText != nil { + mu.Query, err = strconv.Unquote(string(condText.bs)) + if err != nil { + x.SetStatus(w, x.ErrorInvalidRequest, err.Error()) + return + } + } case "application/rdf": // Parse N-Quads. diff --git a/edgraph/server.go b/edgraph/server.go index 6a6a5b3df19..d151f3a1af1 100644 --- a/edgraph/server.go +++ b/edgraph/server.go @@ -576,43 +576,62 @@ func doQueryInUpsert(ctx context.Context, mu *api.Mutation, gmu *gql.Mutation) ( return l, nil } + queryWithIf := mu.Query + if strings.TrimSpace(mu.Cond) != "" { + cond := strings.Replace(mu.Cond, "@if", "@filter", 1) + queryWithIf = mu.Query + `\n var(func: uid(0x01)) ` + cond + ` {_uid_ as uid}` + } + needVars := findVars(gmu) + needVars = append(needVars, "_uid_") startParsingTime := time.Now() parsedReq, err := gql.ParseWithNeedVars(gql.Request{ - Str: mu.Query, + Str: queryWithIf, Variables: make(map[string]string), }, needVars) l.Parsing += time.Since(startParsingTime) if err != nil { - return nil, errors.Wrapf(err, "while parsing query: %q", mu.Query) + return nil, errors.Wrapf(err, "while parsing query: %q", queryWithIf) } if err := validateQuery(parsedReq.Query); err != nil { - return nil, errors.Wrapf(err, "while validating query: %q", mu.Query) + return nil, errors.Wrapf(err, "while validating query: %q", queryWithIf) } qr := query.Request{Latency: l, GqlQuery: &parsedReq, ReadTs: mu.StartTs} if err := qr.ProcessQuery(ctx); err != nil { - return nil, errors.Wrapf(err, "while processing query: %q", mu.Query) + return nil, errors.Wrapf(err, "while processing query: %q", queryWithIf) } if len(qr.Vars) <= 0 { return nil, errors.Errorf("upsert query block has no variables") } - // TODO(Aman): allow multiple values for each variable. // If a variable doesn't have any UID, we generate one ourselves later. - varToUID := make(map[string]string) + varToUID := make(map[string][]string) for name, v := range qr.Vars { if v.Uids == nil { continue } - if len(v.Uids.Uids) > 1 { - return nil, errors.Errorf("more than one values found for var (%s)", name) - } else if len(v.Uids.Uids) == 1 { - varToUID[name] = fmt.Sprintf("%d", v.Uids.Uids[0]) + + if len(v.Uids.Uids) > 0 { + uids := make([]string, len(v.Uids.Uids)) + for i, u := range v.Uids.Uids { + uids[i] = fmt.Sprintf("%d", u) + } + + varToUID[name] = uids } } + // Conditional mutation, we simply return in case condition is false + v, ok := qr.Vars["_uid_"] + isMut := ok && (len(v.Uids.Uids) == 1) + if !isMut { + gmu.Set = nil + gmu.Del = nil + return l, nil + } + updateMutations(gmu, varToUID) return l, nil } @@ -649,39 +668,63 @@ func findVars(gmu *gql.Mutation) []string { // updateMutations does following transformations: // * uid(v) -> 0x123 -- If v is defined in query block // * uid(v) -> _:uid(v) -- Otherwise -func updateMutations(gmu *gql.Mutation, varToUID map[string]string) { - getNewVal := func(s string) string { +func updateMutations(gmu *gql.Mutation, varToUID map[string][]string) { + getNewVal := func(s string) []string { if strings.HasPrefix(s, "uid(") { varName := s[4 : len(s)-1] - if uid, ok := varToUID[varName]; ok { - return uid + if uids, ok := varToUID[varName]; ok { + return uids } - return "_:" + s + return []string{"_:" + s} } - return s + return []string{s} + } + + getNewNQuad := func(nq *api.NQuad, s, o string) *api.NQuad { + // The following copy is fine because we only modify Subject and ObjectId. + // The pointer values are not modified across different copies of NQuad. + n := *nq + + n.Subject = s + n.ObjectId = o + return &n } // Remove the mutations from gmu.Del when no UID was found. - gmuDel := gmu.Del[:0] + gmuDel := make([]*api.NQuad, 0, len(gmu.Del)) for _, nq := range gmu.Del { - nq.Subject = getNewVal(nq.Subject) - nq.ObjectId = getNewVal(nq.ObjectId) - - if !strings.HasPrefix(nq.Subject, "_:uid(") && - !strings.HasPrefix(nq.ObjectId, "_:uid(") { + newSubs := getNewVal(nq.Subject) + newObs := getNewVal(nq.ObjectId) + + for _, s := range newSubs { + for _, o := range newObs { + // Blank node has no meaning in case of deletion. + if strings.HasPrefix(s, "_:uid(") || + strings.HasPrefix(o, "_:uid(") { + continue + } - gmuDel = append(gmuDel, nq) + gmuDel = append(gmuDel, getNewNQuad(nq, s, o)) + } } } gmu.Del = gmuDel // Update the values in mutation block from the query block. + gmuSet := make([]*api.NQuad, 0, len(gmu.Set)) for _, nq := range gmu.Set { - nq.Subject = getNewVal(nq.Subject) - nq.ObjectId = getNewVal(nq.ObjectId) + newSubs := getNewVal(nq.Subject) + newObs := getNewVal(nq.ObjectId) + + for _, s := range newSubs { + for _, o := range newObs { + gmuSet = append(gmuSet, getNewNQuad(nq, s, o)) + } + } } + gmu.Set = gmuSet } // Query handles queries and returns the data. diff --git a/gql/parser.go b/gql/parser.go index ed9052df882..04d116e753c 100644 --- a/gql/parser.go +++ b/gql/parser.go @@ -562,7 +562,7 @@ func ParseWithNeedVars(r Request, needVars []string) (res Result, rerr error) { return res, err } - // Substitute all variables with corresponding values + // Substitute all graphql variables with corresponding values if err := substituteVariables(qu, vmap); err != nil { return res, err } diff --git a/gql/parser_mutation.go b/gql/parser_mutation.go index 1fac61ba14f..6ba2605276f 100644 --- a/gql/parser_mutation.go +++ b/gql/parser_mutation.go @@ -39,11 +39,11 @@ func ParseMutation(mutation string) (mu *api.Mutation, err error) { item := it.Item() switch item.Typ { case itemUpsertBlock: - if mu, err = ParseUpsertBlock(it); err != nil { + if mu, err = parseUpsertBlock(it); err != nil { return nil, err } case itemLeftCurl: - if mu, err = ParseMutationBlock(it); err != nil { + if mu, err = parseMutationBlock(it); err != nil { return nil, err } default: @@ -58,11 +58,11 @@ func ParseMutation(mutation string) (mu *api.Mutation, err error) { return mu, nil } -// ParseUpsertBlock parses the upsert block -func ParseUpsertBlock(it *lex.ItemIterator) (*api.Mutation, error) { +// parseUpsertBlock parses the upsert block +func parseUpsertBlock(it *lex.ItemIterator) (*api.Mutation, error) { var mu *api.Mutation - var queryText string - var queryFound bool + var queryText, condText string + var queryFound, condFound bool // ===>upsert<=== {...} if !it.Next() { @@ -86,6 +86,7 @@ func ParseUpsertBlock(it *lex.ItemIterator) (*api.Mutation, error) { return nil, errors.Errorf("Query op not found in upsert block") } else { mu.Query = queryText + mu.Cond = condText return mu, nil } @@ -109,8 +110,23 @@ func ParseUpsertBlock(it *lex.ItemIterator) (*api.Mutation, error) { if !it.Next() { return nil, errors.Errorf("Unexpected end of upsert block") } + + // upsert { mutation ===>@if(...)<=== {....} query{...}} + item = it.Item() + if item.Typ == itemUpsertBlockOpContent { + if condFound { + return nil, errors.Errorf("Multiple @if directive inside upsert block") + } + condFound = true + condText = item.Val + if !it.Next() { + return nil, errors.Errorf("Unexpected end of upsert block") + } + } + + // upsert @if(...) ===>{<=== ....} var err error - if mu, err = ParseMutationBlock(it); err != nil { + if mu, err = parseMutationBlock(it); err != nil { return nil, err } @@ -133,8 +149,8 @@ func ParseUpsertBlock(it *lex.ItemIterator) (*api.Mutation, error) { return nil, errors.Errorf("Invalid upsert block") } -// ParseMutationBlock parses the mutation block -func ParseMutationBlock(it *lex.ItemIterator) (*api.Mutation, error) { +// parseMutationBlock parses the mutation block +func parseMutationBlock(it *lex.ItemIterator) (*api.Mutation, error) { var mu api.Mutation item := it.Item() diff --git a/gql/state.go b/gql/state.go index e1d12645065..58d7532c2bb 100644 --- a/gql/state.go +++ b/gql/state.go @@ -93,11 +93,7 @@ func lexIdentifyBlock(l *lex.Lexer) lex.StateFn { func lexNameBlock(l *lex.Lexer) lex.StateFn { for { // The caller already checked isNameBegin, and absorbed one rune. - r := l.Next() - if isNameSuffix(r) { - continue - } - l.Backup() + l.AcceptRun(isNameSuffix) switch word := l.Input[l.Start:l.Pos]; word { case "upsert": l.Emit(itemUpsertBlock) @@ -140,11 +136,7 @@ func lexUpsertBlock(l *lex.Lexer) lex.StateFn { func lexNameUpsertOp(l *lex.Lexer) lex.StateFn { for { // The caller already checked isNameBegin, and absorbed one rune. - r := l.Next() - if isNameSuffix(r) { - continue - } - l.Backup() + l.AcceptRun(isNameSuffix) word := l.Input[l.Start:l.Pos] switch word { case "query": @@ -164,33 +156,56 @@ func lexNameUpsertOp(l *lex.Lexer) lex.StateFn { // lexBlockContent lexes and absorbs the text inside a block (covered by braces). func lexBlockContent(l *lex.Lexer) lex.StateFn { + return lexContent(l, leftCurl, rightCurl, lexUpsertBlock) +} + +// lexIfContent lexes the whole of @if directive in a mutation block (covered by small brackets) +func lexIfContent(l *lex.Lexer) lex.StateFn { + if r := l.Next(); r != at { + return l.Errorf("Expected [@], found; [%#U]", r) + } + + l.AcceptRun(isNameSuffix) + word := l.Input[l.Start:l.Pos] + if word != "@if" { + return l.Errorf("Expected @if, found [%v]", word) + } + + return lexContent(l, '(', ')', lexInsideMutation) +} + +func lexContent(l *lex.Lexer, leftRune, rightRune rune, returnTo lex.StateFn) lex.StateFn { depth := 0 for { switch l.Next() { case lex.EOF: - return l.Errorf("Unclosed block (matching braces not found)") + return l.Errorf("Matching brackets not found") case quote: if err := l.LexQuotedString(); err != nil { return l.Errorf(err.Error()) } - case leftCurl: + case leftRune: depth++ - case rightCurl: + case rightRune: depth-- if depth < 0 { - return l.Errorf("Unopened } found") + return l.Errorf("Unopened %s found", rightRune) } else if depth == 0 { l.Emit(itemUpsertBlockOpContent) - return lexUpsertBlock + return returnTo } } } + } func lexInsideMutation(l *lex.Lexer) lex.StateFn { l.Mode = lexInsideMutation for { switch r := l.Next(); { + case r == at: + l.Backup() + return lexIfContent case r == rightCurl: l.Depth-- l.Emit(itemRightCurl) @@ -586,7 +601,7 @@ func lexOperationType(l *lex.Lexer) lex.StateFn { l.Emit(itemOpType) return lexInsideSchema } else { - l.Errorf("Invalid operation type: %s", word) + return l.Errorf("Invalid operation type: %s", word) } return lexQuery diff --git a/gql/upsert_test.go b/gql/upsert_test.go index de02e6c53e4..95b16c1056a 100644 --- a/gql/upsert_test.go +++ b/gql/upsert_test.go @@ -312,3 +312,235 @@ func TestUpsertWithFilter(t *testing.T) { _, err := ParseMutation(query) require.Nil(t, err) } + +func TestConditionalUpsertWithNewlines(t *testing.T) { + query := `upsert { + mutation @if(eq(len(m), 1) + AND + gt(len(f), 0)) { + set { + uid(m) "45" . + uid(f) "45" . + } + } + + query { + me(func: eq(age, 34)) @filter(ge(name, "user")) { + uid + friend { + uid + age + } + } + } +} +` + mu, err := ParseMutation(query) + require.Nil(t, err) + _, err = ParseIfDirective(mu.Cond) + require.Nil(t, err) +} + +func TestConditionalUpsertFuncTree(t *testing.T) { + query := `upsert { + mutation @if( ( eq(len(m), 1) + OR + lt(90, len(h))) + AND + gt(len(f), 0)) { + set { + uid(m) "45" . + uid(f) "45" . + } + } + + query { + me(func: eq(age, 34)) @filter(ge(name, "user")) { + uid + friend { + uid + age + } + } + } +} +` + mu, err := ParseMutation(query) + require.Nil(t, err) + _, err = ParseIfDirective(mu.Cond) + require.Nil(t, err) +} + +func TestConditionalUpsertMultipleFuncArg(t *testing.T) { + query := `upsert { + mutation @if( ( eq(len(m), len(t)) + OR + lt(90, len(h))) + AND + gt(len(f), 0)) { + set { + uid(m) "45" . + uid(f) "45" . + } + } + + query { + me(func: eq(age, 34)) @filter(ge(name, "user")) { + uid + friend { + uid + age + } + } + } +} +` + mu, err := ParseMutation(query) + require.Nil(t, err) + _, err = ParseIfDirective(mu.Cond) + require.Contains(t, err.Error(), "Multiple functions as arguments not allowed") +} + +func TestConditionalUpsertErrMissingRightRound(t *testing.T) { + query := `upsert { + mutation @if(eq(len(m, 1) + AND + gt(len(f), 0)) { + set { + uid(m) "45" . + uid(f) "45" . + } + } + + query { + me(func: eq(age, 34)) @filter(ge(name, "user")) { + uid + friend { + uid + age + } + } + } +} +` + _, err := ParseMutation(query) + require.Contains(t, err.Error(), "Matching brackets not found") +} + +func TestConditionalUpsertErrUnclosed(t *testing.T) { + query := `upsert { + mutation @if(eq(len(m), 1) AND gt(len(f), 0))` + _, err := ParseMutation(query) + require.Contains(t, err.Error(), "Unclosed mutation action") +} + +func TestConditionalUpsertErrInvalidIf(t *testing.T) { + query := `upsert { + mutation @if` + _, err := ParseMutation(query) + require.Contains(t, err.Error(), "Matching brackets not found") +} + +func TestConditionalUpsertErrWrongIf(t *testing.T) { + query := `upsert { + mutation @fi( ( eq(len(m), 1) + OR + lt(len(h), 90)) + AND + gt(len(f), 0)) { + set { + uid(m) "45" . + uid(f) "45" . + } + } + + query { + me(func: eq(age, 34)) @filter(ge(name, "user")) { + uid + friend { + uid + age + } + } + } +} +` + _, err := ParseMutation(query) + require.Contains(t, err.Error(), "Expected @if, found [@fi]") +} + +func TestIfDirectiveNoIf(t *testing.T) { + ft, err := ParseIfDirective("") + require.Nil(t, ft) + require.Nil(t, err) +} + +func TestIfDirectiveWithComment(t *testing.T) { + cond := ` @if( ( eq(len(m), 1) + # This is a comment + OR + lt(90, len(h))) + AND + gt(len(f), 0))` + _, err := ParseIfDirective(cond) + require.Nil(t, err) +} + +func TestIfDirectiveWithSameLineComment(t *testing.T) { + cond := ` @if( ( eq(len(m), 1) + OR # This is another comment + lt(90, len(h)) ) + AND + gt(len(f), 0))` + _, err := ParseIfDirective(cond) + require.Nil(t, err) +} + +func TestIfDirectiveWithAfterIfComment(t *testing.T) { + cond := ` @if # This comment is okay too + ((eq(len(m), 1) + OR # This is another comment + lt(90, len(h)) ) + AND + gt(len(f), 0)) ` + _, err := ParseIfDirective(cond) + require.Nil(t, err) +} + +func TestIfDirectiveErrMissingAt(t *testing.T) { + cond := ` if( ( eq(len(m), 1) + OR + lt(90, len(h))) + AND + gt(len(f), 0))` + _, err := ParseIfDirective(cond) + require.Contains(t, err.Error(), "Expecting argument name. Got: lex.Item [11] \"(\"") +} + +func TestIfDirectiveErrWhiteSpace(t *testing.T) { + cond := ` ` + ft, err := ParseIfDirective(cond) + require.Nil(t, err) + require.Nil(t, ft) +} + +func TestIfDirectiveNoParseErr(t *testing.T) { + cond := ` @if( ( eq(len(m), 1) + OR + lt(90, len(h))) + AND + gt(len(h), 0)) @if()` + _, err := ParseIfDirective(cond) + require.Contains(t, err.Error(), "Repeated filter or if at root") +} + +// TODO(Aman): test fails, This is a bad error message. +func TestIfDirectiveErrMissingRightRound(t *testing.T) { + cond := `@if( ( eq(len(m), 1) + OR + lt(90, len(h))) + AND + gt(len(h), 0)` + _, err := ParseIfDirective(cond) + require.Contains(t, err.Error(), "Unrecognized character inside a func: U+007B '{'") +}