diff --git a/dgraph/cmd/alpha/http.go b/dgraph/cmd/alpha/http.go index 080ff55d9be..c91b973fa87 100644 --- a/dgraph/cmd/alpha/http.go +++ b/dgraph/cmd/alpha/http.go @@ -326,6 +326,13 @@ func mutationHandler(w http.ResponseWriter, r *http.Request) { return } } + if condText, ok := ms["cond"]; ok && condText != nil { + mu.Query, err = strconv.Unquote(string(condText.bs)) + if err != nil { + x.SetStatus(w, x.ErrorInvalidRequest, err.Error()) + return + } + } case "application/rdf": // Parse N-Quads. diff --git a/edgraph/server.go b/edgraph/server.go index 7d23a04ae52..621b8ce192d 100644 --- a/edgraph/server.go +++ b/edgraph/server.go @@ -23,6 +23,7 @@ import ( "math" "os" "sort" + "strconv" "strings" "time" "unicode" @@ -599,27 +600,135 @@ func doQueryInUpsert(ctx context.Context, mu *api.Mutation, gmu *gql.Mutation) ( return nil, errors.Errorf("upsert query block has no variables") } - // TODO(Aman): allow multiple values for each variable. // If a variable doesn't have any UID, we generate one ourselves later. - varToUID := make(map[string]string) + varToUID := make(map[string][]string) for name, v := range qr.Vars { if v.Uids == nil { continue } - if len(v.Uids.Uids) > 1 { - return nil, errors.Errorf("more than one values found for var (%s)", name) - } else if len(v.Uids.Uids) == 1 { - varToUID[name] = fmt.Sprintf("%d", v.Uids.Uids[0]) + + if len(v.Uids.Uids) > 0 { + uids := make([]string, len(v.Uids.Uids)) + for i, u := range v.Uids.Uids { + uids[i] = fmt.Sprintf("%d", u) + } + + varToUID[name] = uids } } + // Conditional mutation, we simply return in case condition is false + if isMut, err := processIfTree(gmu.CondTree, varToUID); err != nil { + return l, err + } else if !isMut { + gmu.Set = nil + gmu.Del = nil + return l, nil + } + updateMutations(gmu, varToUID) return l, nil } +func processIfTree(ft *gql.FilterTree, varToUID map[string][]string) (bool, error) { + if ft == nil { + return true, nil + } + + switch strings.ToLower(ft.Op) { + case "not": + if len(ft.Child) > 1 { + return false, errors.Errorf("Expected 1 child for not but got %d.", len(ft.Child)) + } + + res, err := processIfTree(ft.Child[0], varToUID) + if err != nil { + return false, err + } + return !res, nil + + case "and": + if len(ft.Child) < 2 { + return false, errors.Errorf("Expected >1 children for AND but got %d.", len(ft.Child)) + } + + for _, c := range ft.Child { + res, err := processIfTree(c, varToUID) + if err != nil { + return false, err + } else if !res { + return false, nil + } + } + return true, nil + + case "or": + if len(ft.Child) < 2 { + return false, errors.Errorf("Expected >1 children for OR but got %d.", len(ft.Child)) + } + + for _, c := range ft.Child { + res, err := processIfTree(c, varToUID) + if err != nil { + return false, err + } else if res { + return true, nil + } + } + return false, nil + + default: + if ft.Func != nil && !ft.Func.IsCount && len(ft.Func.Args) != 1 { + return false, errors.Errorf("function not supported in @if directive") + } + + uids := varToUID[ft.Func.Attr] + count := int64(len(uids)) + argVal, err := strconv.ParseInt(ft.Func.Args[0].Value, 10, 64) + if err != nil { + return false, err + } + + // TODO(Aman): what if gt(5, count(g))? + switch ft.Func.Name { + case "ge": + return count >= argVal, nil + case "gt": + return count > argVal, nil + case "le": + return count <= argVal, nil + case "lt": + return count < argVal, nil + case "eq": + return count == argVal, nil + default: + return false, errors.Errorf("invalid function [%v] in if directive", ft.Func.Name) + } + } +} + +// findvarsInCond finds variable used in the condition (the FilterTree) +// and adds them to the provided map (set) vars. +func findVarsInCond(ft *gql.FilterTree, vars map[string]struct{}) { + if ft == nil { + return + } + + if ft.Func != nil && ft.Func.Attr != "" { + vars[ft.Func.Attr] = struct{}{} + } + + // It's cool in case ft.Child is nil. + for _, c := range ft.Child { + findVarsInCond(c, vars) + } +} + // findVars finds all the variables used in mutation block func findVars(gmu *gql.Mutation) []string { vars := make(map[string]struct{}) + findVarsInCond(gmu.CondTree, vars) + updateVars := func(s string) { if strings.HasPrefix(s, "uid(") { varName := s[4 : len(s)-1] @@ -649,39 +758,63 @@ func findVars(gmu *gql.Mutation) []string { // updateMutations does following transformations: // * uid(v) -> 0x123 -- If v is defined in query block // * uid(v) -> _:uid(v) -- Otherwise -func updateMutations(gmu *gql.Mutation, varToUID map[string]string) { - getNewVal := func(s string) string { +func updateMutations(gmu *gql.Mutation, varToUID map[string][]string) { + getNewVal := func(s string) []string { if strings.HasPrefix(s, "uid(") { varName := s[4 : len(s)-1] - if uid, ok := varToUID[varName]; ok { - return uid + if uids, ok := varToUID[varName]; ok { + return uids } - return "_:" + s + return []string{"_:" + s} } - return s + return []string{s} + } + + getNewNQuad := func(nq *api.NQuad, s, o string) *api.NQuad { + // The following copy is fine because we only modify Subject and ObjectId. + // The pointer values are not modified across different copies of NQuad. + n := *nq + + n.Subject = s + n.ObjectId = o + return &n } // Remove the mutations from gmu.Del when no UID was found. - gmuDel := gmu.Del[:0] + gmuDel := make([]*api.NQuad, 0, len(gmu.Del)) for _, nq := range gmu.Del { - nq.Subject = getNewVal(nq.Subject) - nq.ObjectId = getNewVal(nq.ObjectId) - - if !strings.HasPrefix(nq.Subject, "_:uid(") && - !strings.HasPrefix(nq.ObjectId, "_:uid(") { + newSubs := getNewVal(nq.Subject) + newObs := getNewVal(nq.ObjectId) + + for _, s := range newSubs { + for _, o := range newObs { + // Blank node has no meaning in case of deletion. + if strings.HasPrefix(s, "_:uid(") || + strings.HasPrefix(o, "_:uid(") { + continue + } - gmuDel = append(gmuDel, nq) + gmuDel = append(gmuDel, getNewNQuad(nq, s, o)) + } } } gmu.Del = gmuDel // Update the values in mutation block from the query block. + gmuSet := make([]*api.NQuad, 0, len(gmu.Set)) for _, nq := range gmu.Set { - nq.Subject = getNewVal(nq.Subject) - nq.ObjectId = getNewVal(nq.ObjectId) + newSubs := getNewVal(nq.Subject) + newObs := getNewVal(nq.ObjectId) + + for _, s := range newSubs { + for _, o := range newObs { + gmuSet = append(gmuSet, getNewNQuad(nq, s, o)) + } + } } + gmu.Set = gmuSet } // Query handles queries and returns the data. @@ -962,6 +1095,13 @@ func parseMutationObject(mu *api.Mutation) (*gql.Mutation, error) { if err := validateNQuads(res.Set, res.Del); err != nil { return nil, err } + + ft, err := gql.ParseIfDirective(mu.Cond) + if err != nil { + return nil, err + } + res.CondTree = ft + return res, nil }