Skip to content

Commit

Permalink
Add support for executing Conditional Upsert
Browse files Browse the repository at this point in the history
  • Loading branch information
mangalaman93 committed Jul 15, 2019
1 parent 68f5198 commit 2f87be6
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 21 deletions.
7 changes: 7 additions & 0 deletions dgraph/cmd/alpha/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,13 @@ func mutationHandler(w http.ResponseWriter, r *http.Request) {
return
}
}
if condText, ok := ms["cond"]; ok && condText != nil {
mu.Query, err = strconv.Unquote(string(condText.bs))
if err != nil {
x.SetStatus(w, x.ErrorInvalidRequest, err.Error())
return
}
}

case "application/rdf":
// Parse N-Quads.
Expand Down
182 changes: 161 additions & 21 deletions edgraph/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"math"
"os"
"sort"
"strconv"
"strings"
"time"
"unicode"
Expand Down Expand Up @@ -599,27 +600,135 @@ func doQueryInUpsert(ctx context.Context, mu *api.Mutation, gmu *gql.Mutation) (
return nil, errors.Errorf("upsert query block has no variables")
}

// TODO(Aman): allow multiple values for each variable.
// If a variable doesn't have any UID, we generate one ourselves later.
varToUID := make(map[string]string)
varToUID := make(map[string][]string)
for name, v := range qr.Vars {
if v.Uids == nil {
continue
}
if len(v.Uids.Uids) > 1 {
return nil, errors.Errorf("more than one values found for var (%s)", name)
} else if len(v.Uids.Uids) == 1 {
varToUID[name] = fmt.Sprintf("%d", v.Uids.Uids[0])

if len(v.Uids.Uids) > 0 {
uids := make([]string, len(v.Uids.Uids))
for i, u := range v.Uids.Uids {
uids[i] = fmt.Sprintf("%d", u)
}

varToUID[name] = uids
}
}

// Conditional mutation, we simply return in case condition is false
if isMut, err := processIfTree(gmu.CondTree, varToUID); err != nil {
return l, err
} else if !isMut {
gmu.Set = nil
gmu.Del = nil
return l, nil
}

updateMutations(gmu, varToUID)
return l, nil
}

func processIfTree(ft *gql.FilterTree, varToUID map[string][]string) (bool, error) {
if ft == nil {
return true, nil
}

switch strings.ToLower(ft.Op) {
case "not":
if len(ft.Child) > 1 {
return false, errors.Errorf("Expected 1 child for not but got %d.", len(ft.Child))
}

res, err := processIfTree(ft.Child[0], varToUID)
if err != nil {
return false, err
}
return !res, nil

case "and":
if len(ft.Child) < 2 {
return false, errors.Errorf("Expected >1 children for AND but got %d.", len(ft.Child))
}

for _, c := range ft.Child {
res, err := processIfTree(c, varToUID)
if err != nil {
return false, err
} else if !res {
return false, nil
}
}
return true, nil

case "or":
if len(ft.Child) < 2 {
return false, errors.Errorf("Expected >1 children for OR but got %d.", len(ft.Child))
}

for _, c := range ft.Child {
res, err := processIfTree(c, varToUID)
if err != nil {
return false, err
} else if res {
return true, nil
}
}
return false, nil

default:
if ft.Func != nil && !ft.Func.IsCount && len(ft.Func.Args) != 1 {
return false, errors.Errorf("function not supported in @if directive")
}

uids := varToUID[ft.Func.Attr]
count := int64(len(uids))
argVal, err := strconv.ParseInt(ft.Func.Args[0].Value, 10, 64)
if err != nil {
return false, err
}

// TODO(Aman): what if gt(5, count(g))?
switch ft.Func.Name {
case "ge":
return count >= argVal, nil
case "gt":
return count > argVal, nil
case "le":
return count <= argVal, nil
case "lt":
return count < argVal, nil
case "eq":
return count == argVal, nil
default:
return false, errors.Errorf("invalid function [%v] in if directive", ft.Func.Name)
}
}
}

// findvarsInCond finds variable used in the condition (the FilterTree)
// and adds them to the provided map (set) vars.
func findVarsInCond(ft *gql.FilterTree, vars map[string]struct{}) {
if ft == nil {
return
}

if ft.Func != nil && ft.Func.Attr != "" {
vars[ft.Func.Attr] = struct{}{}
}

// It's cool in case ft.Child is nil.
for _, c := range ft.Child {
findVarsInCond(c, vars)
}
}

// findVars finds all the variables used in mutation block
func findVars(gmu *gql.Mutation) []string {
vars := make(map[string]struct{})
findVarsInCond(gmu.CondTree, vars)

updateVars := func(s string) {
if strings.HasPrefix(s, "uid(") {
varName := s[4 : len(s)-1]
Expand Down Expand Up @@ -649,39 +758,63 @@ func findVars(gmu *gql.Mutation) []string {
// updateMutations does following transformations:
// * uid(v) -> 0x123 -- If v is defined in query block
// * uid(v) -> _:uid(v) -- Otherwise
func updateMutations(gmu *gql.Mutation, varToUID map[string]string) {
getNewVal := func(s string) string {
func updateMutations(gmu *gql.Mutation, varToUID map[string][]string) {
getNewVal := func(s string) []string {
if strings.HasPrefix(s, "uid(") {
varName := s[4 : len(s)-1]
if uid, ok := varToUID[varName]; ok {
return uid
if uids, ok := varToUID[varName]; ok {
return uids
}

return "_:" + s
return []string{"_:" + s}
}

return s
return []string{s}
}

getNewNQuad := func(nq *api.NQuad, s, o string) *api.NQuad {
// The following copy is fine because we only modify Subject and ObjectId.
// The pointer values are not modified across different copies of NQuad.
n := *nq

n.Subject = s
n.ObjectId = o
return &n
}

// Remove the mutations from gmu.Del when no UID was found.
gmuDel := gmu.Del[:0]
gmuDel := make([]*api.NQuad, 0, len(gmu.Del))
for _, nq := range gmu.Del {
nq.Subject = getNewVal(nq.Subject)
nq.ObjectId = getNewVal(nq.ObjectId)

if !strings.HasPrefix(nq.Subject, "_:uid(") &&
!strings.HasPrefix(nq.ObjectId, "_:uid(") {
newSubs := getNewVal(nq.Subject)
newObs := getNewVal(nq.ObjectId)

for _, s := range newSubs {
for _, o := range newObs {
// Blank node has no meaning in case of deletion.
if strings.HasPrefix(s, "_:uid(") ||
strings.HasPrefix(o, "_:uid(") {
continue
}

gmuDel = append(gmuDel, nq)
gmuDel = append(gmuDel, getNewNQuad(nq, s, o))
}
}
}
gmu.Del = gmuDel

// Update the values in mutation block from the query block.
gmuSet := make([]*api.NQuad, 0, len(gmu.Set))
for _, nq := range gmu.Set {
nq.Subject = getNewVal(nq.Subject)
nq.ObjectId = getNewVal(nq.ObjectId)
newSubs := getNewVal(nq.Subject)
newObs := getNewVal(nq.ObjectId)

for _, s := range newSubs {
for _, o := range newObs {
gmuSet = append(gmuSet, getNewNQuad(nq, s, o))
}
}
}
gmu.Set = gmuSet
}

// Query handles queries and returns the data.
Expand Down Expand Up @@ -962,6 +1095,13 @@ func parseMutationObject(mu *api.Mutation) (*gql.Mutation, error) {
if err := validateNQuads(res.Set, res.Del); err != nil {
return nil, err
}

ft, err := gql.ParseIfDirective(mu.Cond)
if err != nil {
return nil, err
}
res.CondTree = ft

return res, nil
}

Expand Down

0 comments on commit 2f87be6

Please sign in to comment.