Skip to content

Commit

Permalink
Add support for multiple uids in uid_in function (dgraph-io#5292)
Browse files Browse the repository at this point in the history
Support multiple uids in UID_IN
  • Loading branch information
all-seeing-code authored and dna2github committed Jul 18, 2020
1 parent e7b0ab2 commit c5b00e7
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 14 deletions.
9 changes: 6 additions & 3 deletions gql/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -1556,11 +1556,11 @@ loop:
return nil
}

// parseIneqArgs will try to parse the arguments inside an array ([]). If the values
// parseFuncArgs will try to parse the arguments inside an array ([]). If the values
// are prefixed with $ they are treated as Gql variables, otherwise they are used as scalar values.
// Returns nil on success while appending arguments to the function Args slice. Otherwise
// returns an error, which can be a parsing or value error.
func parseIneqArgs(it *lex.ItemIterator, g *Function) error {
func parseFuncArgs(it *lex.ItemIterator, g *Function) error {
var expectArg, isDollar bool

expectArg = true
Expand Down Expand Up @@ -1764,7 +1764,10 @@ L:
err = parseGeoArgs(it, function)

case IsInequalityFn(function.Name):
err = parseIneqArgs(it, function)
err = parseFuncArgs(it, function)

case function.Name == "uid_in":
err = parseFuncArgs(it, function)

default:
err = itemInFunc.Errorf("Unexpected character [ while parsing request.")
Expand Down
41 changes: 39 additions & 2 deletions gql/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4856,16 +4856,53 @@ func TestParseGraphQLVarArray(t *testing.T) {
require.Equal(t, 1, len(gq.Query))
require.Equal(t, "eq", gq.Query[0].Func.Name)
require.Equal(t, tc.args, len(gq.Query[0].Func.Args))
found := false
var found bool
for _, val := range tc.vars {
found = false
for _, arg := range gq.Query[0].Func.Args {
if val == arg.Value {
found = true
break
}
}
require.True(t, found, "vars not matched: %v", tc.vars)
}
}
}

func TestParseGraphQLVarArrayUID_IN(t *testing.T) {
tests := []struct {
q string
vars map[string]string
args int
}{
// uid_in test cases (uids and predicate inside uid_in are dummy)
{q: `query test($a: string){q(func: uid_in(director.film, [$a])) {name}}`,
vars: map[string]string{"$a": "0x4e472a"}, args: 1},
{q: `query test($a: string, $b: string){q(func: uid_in(director.film, [$a, $b])) {name}}`,
vars: map[string]string{"$a": "0x4e472a", "$b": "0x4e9545"}, args: 2},
{q: `query test($a: string){q(func: uid_in(name, [$a, "0x4e9545"])) {name}}`,
vars: map[string]string{"$a": "0x4e472a"}, args: 2},
{q: `query test($a: string){q(func: uid_in(name, ["0x4e9545", $a])) {name}}`,
vars: map[string]string{"$a": "0x4e472a"}, args: 2},
}
for _, tc := range tests {
gq, err := Parse(Request{Str: tc.q, Variables: tc.vars})
require.NoError(t, err)
require.Equal(t, 1, len(gq.Query))
require.Equal(t, "uid_in", gq.Query[0].Func.Name)
require.Equal(t, tc.args, len(gq.Query[0].Func.Args))
var found bool
for _, val := range tc.vars {
found = false
for _, arg := range gq.Query[0].Func.Args {
if val == arg.Value {
found = true
break
}
}
require.True(t, found, "vars not matched: %v", tc.vars)
}
require.True(t, found, "vars not matched: %v", tc.vars)
}
}

Expand Down
64 changes: 64 additions & 0 deletions query/query1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,70 @@ func TestUidInFunction2(t *testing.T) {
js)
}

func TestUidInFunction3a(t *testing.T) {
// query at top level with unsorted input UIDs
query := `
{
me(func: UID(1, 23, 24)) @filter(uid_in(school, [5001, 5000])) {
name
}
}`
js := processQueryNoErr(t, query)
require.JSONEq(t, `{"data": {"me":[{"name":"Michonne"},{"name":"Rick Grimes"},{"name":"Glenn Rhee"}]}}`, js)
}

func TestUidInFunction3b(t *testing.T) {
// query at top level with sorted input UIDs
query := `
{
me(func: UID(1, 23, 24)) @filter(uid_in(school, [5000, 5001])) {
name
}
}`
js := processQueryNoErr(t, query)
require.JSONEq(t, `{"data": {"me":[{"name":"Michonne"},{"name":"Rick Grimes"},{"name":"Glenn Rhee"}]}}`, js)
}

func TestUidInFunction3c(t *testing.T) {
// query at top level with no UIDs present in predicate
query := `
{
me(func: UID(1, 23, 24)) @filter(uid_in(school, [500, 501])) {
name
}
}`
js := processQueryNoErr(t, query)
require.JSONEq(t, `{"data":{"me":[]}}`, js)
}
func TestUidInFunction4a(t *testing.T) {
// query inside with sorted input UIDs
query := `
{
me(func: uid(1, 23, 24 )) {
friend @filter(uid_in(school, [5000, 5001])) {
name
}
}
}`
js := processQueryNoErr(t, query)
require.JSONEq(t, `{"data": {"me":[{"friend":[{"name":"Rick Grimes"}, {"name":"Glenn Rhee"},{"name":"Daryl Dixon"},{"name":"Andrea"}]},{"friend":[{"name":"Michonne"}]}]}}`,
js)
}

func TestUidInFunction4b(t *testing.T) {
// query inside with unsorted input UIDs + not all uids present in predicate
query := `
{
me(func: uid(1, 23, 24 )) {
friend @filter(uid_in(school, [5001, 500])) {
name
}
}
}`
js := processQueryNoErr(t, query)
require.JSONEq(t, `{"data":{"me":[{"friend":[{"name":"Rick Grimes"},{"name":"Andrea"}]}]}}`,
js)
}
func TestUidInFunctionAtRoot(t *testing.T) {

query := `
Expand Down
26 changes: 17 additions & 9 deletions worker/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ func (qs *queryState) handleUidPostings(
if i == 0 {
span.Annotate(nil, "UidInFn")
}
reqList := &pb.List{Uids: []uint64{srcFn.uidPresent}}
reqList := &pb.List{Uids: srcFn.uidsPresent}
topts := posting.ListOptions{
ReadTs: args.q.ReadTs,
AfterUid: 0,
Expand Down Expand Up @@ -1583,7 +1583,7 @@ type functionContext struct {
ineqValueToken string
n int
threshold int64
uidPresent uint64
uidsPresent []uint64
fname string
fnType FuncType
regex *cregexp.Regexp
Expand Down Expand Up @@ -1816,17 +1816,25 @@ func parseSrcFn(ctx context.Context, q *pb.Query) (*functionContext, error) {
}
checkRoot(q, fc)
case uidInFn:
if err = ensureArgsCount(q.SrcFunc, 1); err != nil {
if len(q.SrcFunc.Args) == 0 {
err := errors.Errorf("Function '%s' requires atleast 1 argument, but got %d (%v)",
q.SrcFunc.Name, len(q.SrcFunc.Args), q.SrcFunc.Args)
return nil, err
}
fc.uidPresent, err = strconv.ParseUint(q.SrcFunc.Args[0], 0, 64)
if err != nil {
if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrSyntax {
return nil, errors.Errorf("Value %q in %s is not a number",
q.SrcFunc.Args[0], q.SrcFunc.Name)
for _, arg := range q.SrcFunc.Args {
uidParsed, err := strconv.ParseUint(arg, 0, 64)
if err != nil {
if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrSyntax {
return nil, errors.Errorf("Value %q in %s is not a number",
arg, q.SrcFunc.Name)
}
return nil, err
}
return nil, err
fc.uidsPresent = append(fc.uidsPresent, uidParsed)
}
sort.Slice(fc.uidsPresent, func(i, j int) bool {
return fc.uidsPresent[i] < fc.uidsPresent[j]
})
checkRoot(q, fc)
if fc.isFuncAtRoot {
return nil, errors.Errorf("uid_in function not allowed at root")
Expand Down

0 comments on commit c5b00e7

Please sign in to comment.