Skip to content

Commit

Permalink
Add sortBy() macro to lists extension.
Browse files Browse the repository at this point in the history
  • Loading branch information
seirl committed Oct 14, 2024
1 parent 4d58cb2 commit 40648f4
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 7 deletions.
1 change: 1 addition & 0 deletions ext/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ go_library(
"//common/types/ref:go_default_library",
"//common/types/traits:go_default_library",
"//interpreter:go_default_library",
"//parser:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
"@org_golang_google_protobuf//reflect/protoreflect:go_default_library",
"@org_golang_google_protobuf//types/known/structpb",
Expand Down
16 changes: 16 additions & 0 deletions ext/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,22 @@ Examples:
[1, "b"].sort() // error
[[1, 2, 3]].sort() // error

### SortBy

**Introduced in version 2**

Sorts a list by a key value, i.e., the order is determined by the result of
an expression applied to each element of the list.

Examples:

[
Player { name: "foo", score: 0 },
Player { name: "bar", score: -10 },
Player { name: "baz", score: 1000 },
].sortBy(e, e.score).map(e, e.name)
== ["bar", "foo", "baz"]

## Sets

Sets provides set relationship tests.
Expand Down
141 changes: 134 additions & 7 deletions ext/lists.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ import (
"sort"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
"github.com/google/cel-go/parser"
)

var comparableTypes = []*cel.Type{
Expand Down Expand Up @@ -124,6 +126,25 @@ var comparableTypes = []*cel.Type{
// ["b", "c", "a"].sort() // return ["a", "b", "c"]
// [1, "b"].sort() // error
// [[1, 2, 3]].sort() // error
//
// # SortBy
//
// Sorts a list by a key value, i.e., the order is determined by the result of
// an expression applied to each element of the list.
// The output of the key expression must be a comparable type, otherwise the
// function will return an error.
//
// <list(T)>.sortBy(<bindingName>, <keyExpr>) -> <list(T)>
// keyExpr returns a value in {int, uint, double, bool, duration, timestamp, string, bytes}

// Examples:
//
// [
// Player { name: "foo", score: 0 },
// Player { name: "bar", score: -10 },
// Player { name: "baz", score: 1000 },
// ].sortBy(e, e.score).map(e, e.name)
// == ["bar", "foo", "baz"]

func Lists(options ...ListsOption) cel.EnvOption {
l := &listsLib{
Expand Down Expand Up @@ -258,6 +279,37 @@ func (lib listsLib) CompileOptions() []cel.EnvOption {
)...,
)
opts = append(opts, sortDecl)
opts = append(opts, cel.Macros(cel.ReceiverMacro("sortBy", 2, sortByMacro)))
opts = append(opts, cel.Function("@sortByAssociatedKeys",
append(
templatedOverloads(comparableTypes, func(u *cel.Type) cel.FunctionOpt {
return cel.MemberOverload(
fmt.Sprintf("list_%s_sortByAssociatedKeys", u.TypeName()),
[]*cel.Type{listType, cel.ListType(u)}, listType,
)
}),
cel.SingletonBinaryBinding(
func(arg1 ref.Val, arg2 ref.Val) ref.Val {
list, ok := arg1.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(arg1)
}
keys, ok := arg2.(traits.Lister)
if !ok {
return types.MaybeNoSuchOverloadErr(arg2)
}
sorted, err := sortListByAssociatedKeys(list, keys)
if err != nil {
return types.WrapErr(err)
}

return sorted
},
// List traits
traits.ListerType,
),
)...,
))

opts = append(opts, cel.Function("lists.range",
cel.Overload("lists_range",
Expand Down Expand Up @@ -375,31 +427,106 @@ func flatten(list traits.Lister, depth int64) ([]ref.Val, error) {
}

func sortList(list traits.Lister) (ref.Val, error) {
return sortListByAssociatedKeys(list, list)
}

// Internal function used for the implementation of sort() and sortBy().
//
// Sorts a list of arbitrary elements, according to the order produced by sorting
// another list of comparable elements. If the element type of the keys is not
// comparable or the element types are not the same, the function will produce an error.
//
// <list(T)>.@sortByAssociatedKeys(<list(U)>) -> <list(T)>
// U in {int, uint, double, bool, duration, timestamp, string, bytes}
//
// Example:
//
// ["foo", "bar", "baz"].@sortByAssociatedKeys([3, 1, 2]) // return ["bar", "baz", "foo"]
func sortListByAssociatedKeys(list, keys traits.Lister) (ref.Val, error) {
listLength := list.Size().(types.Int)
keysLength := keys.Size().(types.Int)
if listLength != keysLength {
return nil, fmt.Errorf(
"@sortByAssociatedKeys() expected a list of the same size as the associated keys list, but got %d and %d elements respectively",
listLength,
keysLength,
)
}
if listLength == 0 {
return list, nil
}
elem := list.Get(types.IntZero)
elem := keys.Get(types.IntZero)
if _, ok := elem.(traits.Comparer); !ok {
return nil, fmt.Errorf("list elements must be comparable")
}

sorted := make([]ref.Val, 0, listLength)
sortedIndices := make([]ref.Val, 0, listLength)
for i := types.IntZero; i < listLength; i++ {
val := list.Get(i)
if val.Type() != elem.Type() {
if keys.Get(i).Type() != elem.Type() {
return nil, fmt.Errorf("list elements must have the same type")
}
sorted = append(sorted, val)
sortedIndices = append(sortedIndices, i)
}

sort.Slice(sorted, func(i, j int) bool {
return sorted[i].(traits.Comparer).Compare(sorted[j]) == types.IntNegOne
sort.Slice(sortedIndices, func(i, j int) bool {
iKey := keys.Get(sortedIndices[i])
jKey := keys.Get(sortedIndices[j])
return iKey.(traits.Comparer).Compare(jKey) == types.IntNegOne
})

sorted := make([]ref.Val, 0, listLength)

for _, sortedIdx := range sortedIndices {
sorted = append(sorted, list.Get(sortedIdx))
}
return types.DefaultTypeAdapter.NativeToValue(sorted), nil
}

// sortByMacro transforms an expression like:
//
// mylistExpr.sortBy(e, -math.abs(e))
//
// into:
//
// cel.bind(
// __sortBy_input__,
// myListExpr,
// __sortBy_input__.@sortByAssociatedKeys(__sortBy_input__.map(e, -math.abs(e))
// )
func sortByMacro(meh cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) {
varIdent := meh.NewIdent("@__sortBy_input__")
varName := varIdent.AsIdent()

targetKind := target.Kind()
if targetKind != ast.ListKind &&
targetKind != ast.SelectKind &&
targetKind != ast.IdentKind &&
targetKind != ast.ComprehensionKind && targetKind != ast.CallKind {
return nil, meh.NewError(target.ID(), fmt.Sprintf("sortBy can only be applied to a list, identifier, comprehension, call or select expression"))
}

mapCompr, err := parser.MakeMap(meh, meh.Copy(varIdent), args)
if err != nil {
return nil, err
}
callExpr := meh.NewMemberCall("@sortByAssociatedKeys",
meh.Copy(varIdent),
mapCompr,
)

bindExpr := meh.NewComprehension(
meh.NewList(),
"#unused",
varName,
target,
meh.NewLiteral(types.False),
varIdent,
callExpr,
)

return bindExpr, nil
}

func distinctList(list traits.Lister) (ref.Val, error) {
listLength := list.Size().(types.Int)
if listLength == 0 {
Expand Down
6 changes: 6 additions & 0 deletions ext/lists_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ func TestLists(t *testing.T) {
{expr: `[4, 3, 2, 1].sort() == [1, 2, 3, 4]`},
{expr: `["d", "a", "b", "c"].sort() == ["a", "b", "c", "d"]`},
{expr: `["d", 3, 2, "c"].sort() == ["a", "b", "c", "d"]`, err: "list elements must have the same type"},
{expr: `[].sortBy(e, e) == []`},
{expr: `["a"].sortBy(e, e) == ["a"]`},
{expr: `[-3, 1, -5, -2, 4].sortBy(e, -(e * e)) == [-5, 4, -3, -2, 1]`},
{expr: `[-3, 1, -5, -2, 4].map(e, e * 2).sortBy(e, -(e * e)) == [-10, 8, -6, -4, 2]`},
{expr: `lists.range(3).sortBy(e, -e) == [2, 1, 0]`},
{expr: `["a", "c", "b", "first"].sortBy(e, e == "first" ? "" : e) == ["first", "a", "b", "c"]`},
{expr: `[].distinct() == []`},
{expr: `[1].distinct() == [1]`},
{expr: `[-2, 5, -2, 1, 1, 5, -2, 1].distinct() == [-2, 5, 1]`},
Expand Down

0 comments on commit 40648f4

Please sign in to comment.