diff --git a/gql/parser.go b/gql/parser.go index ed9052df882..cab1359af1c 100644 --- a/gql/parser.go +++ b/gql/parser.go @@ -31,9 +31,11 @@ import ( ) const ( - uid = "uid" - value = "val" - typ = "type" + uidFunc = "uid" + valueFunc = "val" + typFunc = "type" + lenFunc = "len" + countFunc = "count" ) // GraphQuery stores the parsed Query in a tree format. This gets converted to @@ -159,7 +161,7 @@ type FilterTree struct { // Arg stores an argument to a function. type Arg struct { Value string - IsValueVar bool // If argument is val(a) + IsValueVar bool // If argument is val(a), e.g. eq(name, val(a)) IsGraphQLVar bool } @@ -173,6 +175,7 @@ type Function struct { NeedsVar []VarContext // If the function requires some variable IsCount bool // gt(count(friends),0) IsValueVar bool // eq(val(s), 5) + IsLenVar bool // eq(len(s), 5) } // filterOpPrecedence is a map from filterOp (a string) to its precedence. @@ -420,7 +423,7 @@ func substituteVariablesFilter(f *FilterTree, vmap varMap) error { } for idx, v := range f.Func.Args { - if f.Func.Name == uid { + if f.Func.Name == uidFunc { // This is to support GraphQL variables in uid functions. idVal, ok := vmap[v.Value] if !ok { @@ -562,7 +565,7 @@ func ParseWithNeedVars(r Request, needVars []string) (res Result, rerr error) { return res, err } - // Substitute all variables with corresponding values + // Substitute all graphql variables with corresponding values if err := substituteVariables(qu, vmap); err != nil { return res, err } @@ -1186,7 +1189,7 @@ func parseArguments(it *lex.ItemIterator, gq *GraphQuery) (result []pair, rerr e it.Next() item = it.Item() var val string - if item.Val == value { + if item.Val == valueFunc { count, err := parseVarList(it, gq) if err != nil { return result, err @@ -1263,9 +1266,13 @@ func (f *FilterTree) stringHelper(buf *bytes.Buffer) { buf.WriteRune(' ') if f.Func.IsCount { buf.WriteString("count(") + } else if f.Func.IsValueVar { + buf.WriteString("val(") + } else if f.Func.IsLenVar { + buf.WriteString("len(") } buf.WriteString(f.Func.Attr) - if f.Func.IsCount { + if f.Func.IsCount || f.Func.IsValueVar || f.Func.IsLenVar { buf.WriteRune(')') } if len(f.Func.Lang) > 0 { @@ -1544,7 +1551,7 @@ L: return nil, err } seenFuncArg = true - if nestedFunc.Name == value { + if nestedFunc.Name == valueFunc { if len(nestedFunc.NeedsVar) > 1 { return nil, itemInFunc.Errorf("Multiple variables not allowed in a function") } @@ -1559,13 +1566,25 @@ L: } function.NeedsVar = append(function.NeedsVar, nestedFunc.NeedsVar...) function.NeedsVar[0].Typ = ValueVar - } else { - if nestedFunc.Name != "count" { - return nil, itemInFunc.Errorf("Only val/count allowed as function "+ - "within another. Got: %s", nestedFunc.Name) + } else if nestedFunc.Name == lenFunc { + if len(nestedFunc.NeedsVar) > 1 { + return nil, + itemInFunc.Errorf("Multiple variables not allowed in len function") } + if !isInequalityFn(function.Name) { + return nil, + itemInFunc.Errorf("len function only allowed inside inequality" + + " function") + } + function.Attr = nestedFunc.NeedsVar[0].Name + function.IsLenVar = true + function.NeedsVar = append(function.NeedsVar, nestedFunc.NeedsVar...) + } else if nestedFunc.Name == countFunc { function.Attr = nestedFunc.Attr function.IsCount = true + } else { + return nil, itemInFunc.Errorf("Only val/count/len allowed as function "+ + "within another. Got: %s", nestedFunc.Name) } expectArg = false continue @@ -1651,7 +1670,7 @@ L: if isDollar { val = "$" + val isDollar = false - if function.Name == uid && gq != nil { + if function.Name == uidFunc && gq != nil { if len(gq.Args["id"]) > 0 { return nil, itemInFunc.Errorf("Only one GraphQL variable " + "allowed inside uid function.") @@ -1665,7 +1684,9 @@ L: } // Unlike other functions, uid function has no attribute, everything is args. - if len(function.Attr) == 0 && function.Name != uid && function.Name != typ { + if len(function.Attr) == 0 && function.Name != uidFunc && + function.Name != typFunc { + if strings.ContainsRune(itemInFunc.Val, '"') { return nil, itemInFunc.Errorf("Attribute in function"+ " must not be quoted with \": %s", itemInFunc.Val) @@ -1679,7 +1700,7 @@ L: } function.Lang = val expectLang = false - } else if function.Name != uid { + } else if function.Name != uidFunc { // For UID function. we set g.UID function.Args = append(function.Args, Arg{Value: val}) } @@ -1689,13 +1710,20 @@ L: } expectArg = false - if function.Name == value { + if function.Name == valueFunc { // E.g. @filter(gt(val(a), 10)) function.NeedsVar = append(function.NeedsVar, VarContext{ Name: val, Typ: ValueVar, }) - } else if function.Name == uid { + } else if function.Name == lenFunc { + // E.g. @filter(gt(len(a), 10)) + // TODO(Aman): type could be ValueVar too! + function.NeedsVar = append(function.NeedsVar, VarContext{ + Name: val, + Typ: UidVar, + }) + } else if function.Name == uidFunc { // uid function could take variables as well as actual uids. // If we can parse the value that means its an uid otherwise a variable. uid, err := strconv.ParseUint(val, 0, 64) @@ -1723,11 +1751,11 @@ L: } } - if function.Name != uid && function.Name != typ && len(function.Attr) == 0 { + if function.Name != uidFunc && function.Name != typFunc && len(function.Attr) == 0 { return nil, it.Errorf("Got empty attr for function: [%s]", function.Name) } - if function.Name == typ && len(function.Args) != 1 { + if function.Name == typFunc && len(function.Args) != 1 { return nil, it.Errorf("type function only supports one argument. Got: %v", function.Args) } @@ -2453,7 +2481,7 @@ func getRoot(it *lex.ItemIterator) (gq *GraphQuery, rerr error) { } } - if peekIt[0].Val == uid { + if peekIt[0].Val == uidFunc { gen, err := parseFunction(it, gq) if err != nil { return gq, err @@ -2509,7 +2537,7 @@ func getRoot(it *lex.ItemIterator) (gq *GraphQuery, rerr error) { item = it.Item() } - if val == "" && item.Val == value { + if val == "" && item.Val == valueFunc { count, err := parseVarList(it, gq) if err != nil { return nil, err @@ -2524,7 +2552,7 @@ func getRoot(it *lex.ItemIterator) (gq *GraphQuery, rerr error) { // Get language list, if present items, err := it.Peek(1) if err == nil && items[0].Typ == itemLeftRound { - if (key == "orderasc" || key == "orderdesc") && val != value { + if (key == "orderasc" || key == "orderdesc") && val != valueFunc { return nil, it.Errorf("Expected val(). Got %s() with order.", val) } } @@ -2699,7 +2727,7 @@ func godeep(it *lex.ItemIterator, gq *GraphQuery) error { continue } else if isAggregator(valLower) { child := &GraphQuery{ - Attr: value, + Attr: valueFunc, Args: make(map[string]string), Var: varName, IsInternal: true, @@ -2727,7 +2755,7 @@ func godeep(it *lex.ItemIterator, gq *GraphQuery) error { child.Attr = attr child.IsInternal = false } else { - if it.Item().Val != value { + if it.Item().Val != valueFunc { return it.Errorf("Only variables allowed in aggregate functions. Got: %v", it.Item().Val) } @@ -2792,7 +2820,7 @@ func godeep(it *lex.ItemIterator, gq *GraphQuery) error { IsInternal: true, } switch item.Val { - case value: + case valueFunc: count, err := parseVarList(it, child) if err != nil { return err @@ -2835,10 +2863,10 @@ func godeep(it *lex.ItemIterator, gq *GraphQuery) error { } if peekIt[0].Typ == itemRightRound { return it.Errorf("Cannot use count(), please use count(uid)") - } else if peekIt[0].Val == uid && peekIt[1].Typ == itemRightRound { + } else if peekIt[0].Val == uidFunc && peekIt[1].Typ == itemRightRound { if gq.IsGroupby { // count(uid) case which occurs inside @groupby - val = uid + val = uidFunc // Skip uid) it.Next() it.Next() @@ -2855,7 +2883,7 @@ func godeep(it *lex.ItemIterator, gq *GraphQuery) error { it.Next() } continue - } else if valLower == value { + } else if valLower == valueFunc { if varName != "" { return it.Errorf("Cannot assign a variable to val()") } @@ -2889,7 +2917,7 @@ func godeep(it *lex.ItemIterator, gq *GraphQuery) error { gq.Children = append(gq.Children, child) curp = nil continue - } else if valLower == uid { + } else if valLower == uidFunc { if count == seen { return it.Errorf("Count of a variable is not allowed") } diff --git a/gql/parser_test.go b/gql/parser_test.go index c9f5f1f7532..ada26182184 100644 --- a/gql/parser_test.go +++ b/gql/parser_test.go @@ -974,6 +974,103 @@ func TestParseQueryWithVarInIneqError(t *testing.T) { require.Contains(t, err.Error(), "Multiple variables not allowed in a function") } +func TestLenFunctionWithMultipleVariableError(t *testing.T) { + query := ` + { + var(func: uid(0x0a)) { + fr as friends { + a as age + } + } + + me(func: uid(fr)) @filter(gt(len(a, b), 10)) { + name + } + } +` + // Multiple vars not allowed. + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "Multiple variables not allowed in len function") +} + +func TestLenFunctionInsideUidError(t *testing.T) { + query := ` + { + var(func: uid(0x0a)) { + fr as friends { + a as age + } + } + + me(func: uid(fr)) @filter(uid(len(a), 10)) { + name + } + } +` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "len function only allowed inside inequality") +} + +func TestLenFunctionWithNoVariable(t *testing.T) { + query := ` + { + var(func: uid(0x0a)) { + fr as friends { + a as age + } + } + + me(func: uid(fr)) @filter(len(), 10) { + name + } + } +` + _, err := Parse(Request{Str: query}) + require.Error(t, err) + require.Contains(t, err.Error(), "Got empty attr for function") +} + +func TestLenAsSecondArgumentError(t *testing.T) { + query := ` + { + var(func: uid(0x0a)) { + fr as friends { + a as age + } + } + + me(func: uid(fr)) @filter(10, len(fr)) { + name + } + } +` + _, err := Parse(Request{Str: query}) + // TODO(pawan) - Error message can be improved. We should validate function names from a + // whitelist. + require.Error(t, err) +} + +func TestCountWithLenFunctionError(t *testing.T) { + query := ` + { + var(func: uid(0x0a)) { + fr as friends { + a as age + } + } + + me(func: uid(fr)) @filter(count(name), len(fr)) { + name + } + } +` + _, err := Parse(Request{Str: query}) + // TODO(pawan) - Error message can be improved. + require.Error(t, err) +} + func TestParseQueryWithVarInIneq(t *testing.T) { query := ` { @@ -4042,7 +4139,7 @@ func TestEqUidFunctionErr(t *testing.T) { ` _, err := Parse(Request{Str: query}) require.Error(t, err) - require.Contains(t, err.Error(), "Only val/count allowed as function within another. Got: uid") + require.Contains(t, err.Error(), "Only val/count/len allowed as function within another. Got: uid") } func TestAggRoot1(t *testing.T) { diff --git a/query/query.go b/query/query.go index dacf58669a4..288ffc3741c 100644 --- a/query/query.go +++ b/query/query.go @@ -193,6 +193,7 @@ type Function struct { Args []gql.Arg // Contains the arguments of the function. IsCount bool // gt(count(friends),0) IsValueVar bool // eq(val(s), 10) + IsLenVar bool // eq(len(s), 10) } // SubGraph is the way to represent data. It contains both the request parameters and the response. @@ -276,6 +277,7 @@ func (sg *SubGraph) createSrcFunction(gf *gql.Function) { Args: append(gf.Args[:0:0], gf.Args...), IsCount: gf.IsCount, IsValueVar: gf.IsValueVar, + IsLenVar: gf.IsLenVar, } // type function is just an alias for eq(type, "dgraph.type"). @@ -284,6 +286,7 @@ func (sg *SubGraph) createSrcFunction(gf *gql.Function) { sg.SrcFunc.Name = "eq" sg.SrcFunc.IsCount = false sg.SrcFunc.IsValueVar = false + sg.SrcFunc.IsLenVar = false return } @@ -1895,13 +1898,32 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { // when multiple filters replace their sg.DestUIDs sg.DestUIDs = &pb.List{Uids: sg.SrcUIDs.Uids} } else { - if sg.SrcFunc != nil && isInequalityFn(sg.SrcFunc.Name) && sg.SrcFunc.IsValueVar { + isInequalityFn := sg.SrcFunc != nil && isInequalityFn(sg.SrcFunc.Name) + if isInequalityFn && sg.SrcFunc.IsValueVar { // This is a ineq function which uses a value variable. err = sg.applyIneqFunc() if parent != nil { rch <- err return } + } else if isInequalityFn && sg.SrcFunc.IsLenVar { + // Safe to access 0th element here because if no variable was given, parser would throw + // an error. + val := sg.SrcFunc.Args[0].Value + src := types.Val{Tid: types.StringID, Value: []byte(val)} + dst, err := types.Convert(src, types.IntID) + if err != nil { + // TODO(Aman): needs to do parent check? + rch <- errors.Wrapf(err, "invalid argument %v. Comparing with different type", val) + return + } + + curVal := types.Val{Tid: types.IntID, Value: int64(len(sg.SrcUIDs.Uids))} + if types.CompareVals(sg.SrcFunc.Name, curVal, dst) { + sg.DestUIDs.Uids = sg.SrcUIDs.Uids + } else { + sg.DestUIDs.Uids = nil + } } else { taskQuery, err := createTaskQuery(sg) if err != nil { diff --git a/query/query0_test.go b/query/query0_test.go index 7fa982bec29..3b0ccd3c454 100644 --- a/query/query0_test.go +++ b/query/query0_test.go @@ -1756,6 +1756,158 @@ func TestCountUidToVar(t *testing.T) { require.JSONEq(t, `{"data": {"me":[{"score": 3}]}}`, js) } +func TestFilterUsingLenFunction(t *testing.T) { + tests := []struct { + name, in, out string + }{ + { + "Eq length should return results", + `{ + var(func: has(school), first: 3) { + f as uid + } + + me(func: uid(f)) @filter(eq(len(f), 3)) { + count(uid) + } + }`, + `{"data": {"me":[{"count": 3}]}}`, + }, + { + "Eq length should return empty results", + `{ + var(func: has(school), first: 3) { + f as uid + } + + me(func: uid(f)) @filter(eq(len(f), 0)) { + count(uid) + } + }`, + `{"data": {"me":[{"count": 0}]}}`, + }, + { + "Ge length should return results", + `{ + var(func: has(school), first: 3) { + f as uid + } + + me(func: uid(f)) @filter(ge(len(f), 0)) { + count(uid) + } + }`, + `{"data": {"me":[{"count": 3}]}}`, + }, + { + "Lt length should return results", + `{ + var(func: has(school), first: 3) { + f as uid + } + + me(func: uid(f)) @filter(lt(len(f), 100)) { + count(uid) + } + }`, + + `{"data": {"me":[{"count": 3}]}}`, + }, + { + "Multiple length conditions", + `{ + var(func: has(school), first: 3) { + f as uid + } + + f2 as var(func: has(name), first: 5) + + me(func: uid(f2)) @filter(lt(len(f), 100) AND lt(len(f2), 10)) { + count(uid) + } + }`, + + `{"data": {"me":[{"count": 5}]}}`, + }, + { + "Filter in child with true result", + `{ + var(func: has(school), first: 3) { + f as uid + } + + me(func: uid(f)) { + name + friend @filter(lt(len(f), 100)) { + name + } + } + }`, + `{"data":{"me":[{"name":"Michonne","friend":[{"name":"Rick Grimes"}, + {"name":"Glenn Rhee"},{"name":"Daryl Dixon"},{"name":"Andrea"}]}, + {"name":"Rick Grimes","friend":[{"name":"Michonne"}]}, + {"name":"Glenn Rhee"}]}}`, + }, + { + "Filter in child with false result", + `{ + var(func: has(school), first: 3) { + f as uid + } + + me(func: uid(f)) { + name + friend @filter(gt(len(f), 100)) { + name + } + } + }`, + + `{"data":{"me":[{"name":"Michonne"},{"name":"Rick Grimes"}, + {"name":"Glenn Rhee"}]}}`, + }, + } + + for _, tc := range tests { + t.Log("Running: ", tc.name) + js := processQueryNoErr(t, tc.in) + require.JSONEq(t, tc.out, js) + } +} + +func TestCountOnVarAtRootErr(t *testing.T) { + query := ` + { + var(func: has(school), first: 3) { + f as count(uid) + } + + me(func: len(f)) { + score: math(f) + } + } + ` + _, err := processQuery(context.Background(), t, query) + require.Error(t, err) + require.Contains(t, err.Error(), "Function name: len is not valid") +} + +func TestFilterUsingLenFunctionWithMath(t *testing.T) { + query := ` + { + var(func: has(school), first: 3) { + f as count(uid) + } + + me(func: uid(f)) @filter(lt(len(f), 100)) { + score: math(f) + } + } + ` + js := processQueryNoErr(t, query) + require.JSONEq(t, `{"data": {"me":[{"score": 3}]}}`, js) +} + func TestCountUidToVarMultiple(t *testing.T) { query := ` { diff --git a/worker/task.go b/worker/task.go index 4157c46d084..bc3753506a9 100644 --- a/worker/task.go +++ b/worker/task.go @@ -1438,7 +1438,7 @@ func parseSrcFn(q *pb.Query) (*functionContext, error) { case notAFunction: fc.n = len(q.UidList.Uids) case aggregatorFn: - // confirm agrregator could apply on the attributes + // confirm aggregator could apply on the attributes typ, err := schema.State().TypeOf(attr) if err != nil { return nil, errors.Errorf("Attribute %q is not scalar-type", attr)