diff --git a/gql/parser.go b/gql/parser.go index a1e11cdd5ef..04468a54606 100644 --- a/gql/parser.go +++ b/gql/parser.go @@ -2595,7 +2595,7 @@ func parseLanguageList(it *lex.ItemIterator) ([]string, error) { func validKeyAtRoot(k string) bool { switch k { - case "func", "orderasc", "orderdesc", "first", "offset", "after": + case "func", "orderasc", "orderdesc", "first", "offset", "after", "random": return true case "from", "to", "numpaths", "minweight", "maxweight": // Specific to shortest path @@ -2609,7 +2609,7 @@ func validKeyAtRoot(k string) bool { // Check for validity of key at non-root nodes. func validKey(k string) bool { switch k { - case "orderasc", "orderdesc", "first", "offset", "after": + case "orderasc", "orderdesc", "first", "offset", "after", "random": return true } return false diff --git a/query/query.go b/query/query.go index 9368da565a6..b7b423f723f 100644 --- a/query/query.go +++ b/query/query.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "math" + "math/rand" "sort" "strconv" "strings" @@ -114,6 +115,8 @@ type params struct { Count int // Offset is the value of the "offset" parameter. Offset int + // Random is the value of the "random" parameter + Random int // AfterUID is the value of the "after" parameter. AfterUID uint64 // DoCount is true if the count of the predicate is requested instead of its value. @@ -745,6 +748,15 @@ func (args *params) fill(gq *gql.GraphQuery) error { } args.Count = int(first) } + + if v, ok := gq.Args["random"]; ok { + random, err := strconv.ParseInt(v, 0, 32) + if err != nil { + return err + } + args.Random = int(random) + } + return nil } @@ -2298,6 +2310,13 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { } } + if sg.Params.Random > 0 { + if err = sg.applyRandom(ctx); err != nil { + rch <- err + return + } + } + // Here we consider handling count with filtering. We do this after // pagination because otherwise, we need to do the count with pagination // taken into account. For example, a PL might have only 50 entries but the @@ -2395,6 +2414,43 @@ func ProcessGraph(ctx context.Context, sg, parent *SubGraph, rch chan error) { rch <- childErr } +// stores index of a uid as the index in the uidMatrix (x) +// and index in the corresponding list of the uidMatrix (y) +type UidKey struct { + x int + y int +} + +// applies "random" to lists inside uidMatrix +// sg.Params.Random number of nodes are selected in each uid list +// duplicates are avoided (random selection without replacement) +// if sg.Params.Random is more than the number of available nodes +// all nodes are returned +func (sg *SubGraph) applyRandom(ctx context.Context) error { + sg.updateUidMatrix() + + for i := 0; i < len(sg.uidMatrix); i++ { + // shuffle the uid list and select the + // first sg.Params.Random uids + + uidList := sg.uidMatrix[i].Uids + + rand.Shuffle(len(uidList), func(i, j int) { + uidList[i], uidList[j] = uidList[j], uidList[i] + }) + + numRandom := sg.Params.Random + if sg.Params.Random > len(uidList) { + numRandom = len(uidList) + } + + sg.uidMatrix[i].Uids = uidList[:numRandom] + } + + sg.DestMap = codec.Merge(sg.uidMatrix) + return nil +} + // applyPagination applies count and offset to lists inside uidMatrix. func (sg *SubGraph) applyPagination(ctx context.Context) error { if sg.Params.Count == 0 && sg.Params.Offset == 0 { // No pagination. @@ -2638,7 +2694,7 @@ func (sg *SubGraph) sortAndPaginateUsingVar(ctx context.Context) error { func isValidArg(a string) bool { switch a { case "numpaths", "from", "to", "orderasc", "orderdesc", "first", "offset", "after", "depth", - "minweight", "maxweight": + "minweight", "maxweight", "random": return true } return false