Skip to content

Commit

Permalink
Merge pull request #2678 from darkdrag00nv2/filter_array
Browse files Browse the repository at this point in the history
Introduce `filter` in Fixed/Variable sized Array types
  • Loading branch information
SupunS authored Aug 14, 2023
2 parents 016525d + 9c2eb9d commit af119ed
Show file tree
Hide file tree
Showing 4 changed files with 598 additions and 4 deletions.
104 changes: 100 additions & 4 deletions runtime/interpreter/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -1623,7 +1623,7 @@ func NewArrayValueWithIterator(
interpreter *Interpreter,
arrayType ArrayStaticType,
address common.Address,
count uint64,
countOverestimate uint64,
values func() Value,
) *ArrayValue {
interpreter.ReportComputation(common.ComputationKindCreateArrayValue, 1)
Expand Down Expand Up @@ -1668,7 +1668,7 @@ func NewArrayValueWithIterator(
return array
}
// must assign to v here for tracing to work properly
v = newArrayValueFromConstructor(interpreter, arrayType, count, constructor)
v = newArrayValueFromConstructor(interpreter, arrayType, countOverestimate, constructor)
return v
}

Expand All @@ -1685,14 +1685,14 @@ func newArrayValueFromAtreeValue(
func newArrayValueFromConstructor(
gauge common.MemoryGauge,
staticType ArrayStaticType,
count uint64,
countOverestimate uint64,
constructor func() *atree.Array,
) (array *ArrayValue) {
var elementSize uint
if staticType != nil {
elementSize = staticType.ElementType().elementSize()
}
baseUsage, elementUsage, dataSlabs, metaDataSlabs := common.NewArrayMemoryUsages(count, elementSize)
baseUsage, elementUsage, dataSlabs, metaDataSlabs := common.NewArrayMemoryUsages(countOverestimate, elementSize)
common.UseMemory(gauge, baseUsage)
common.UseMemory(gauge, elementUsage)
common.UseMemory(gauge, dataSlabs)
Expand Down Expand Up @@ -2459,6 +2459,29 @@ func (v *ArrayValue) GetMember(interpreter *Interpreter, locationRange LocationR
)
},
)

case sema.ArrayTypeFilterFunctionName:
return NewHostFunctionValue(
interpreter,
sema.ArrayFilterFunctionType(
interpreter,
v.SemaType(interpreter).ElementType(false),
),
func(invocation Invocation) Value {
interpreter := invocation.Interpreter

funcArgument, ok := invocation.Arguments[0].(FunctionValue)
if !ok {
panic(errors.NewUnreachableError())
}

return v.Filter(
interpreter,
invocation.LocationRange,
funcArgument,
)
},
)
}

return nil
Expand Down Expand Up @@ -2962,6 +2985,79 @@ func (v *ArrayValue) Reverse(
)
}

func (v *ArrayValue) Filter(
interpreter *Interpreter,
locationRange LocationRange,
procedure FunctionValue,
) Value {

elementTypeSlice := []sema.Type{v.semaType.ElementType(false)}
iterationInvocation := func(arrayElement Value) Invocation {
invocation := NewInvocation(
interpreter,
nil,
nil,
[]Value{arrayElement},
elementTypeSlice,
nil,
locationRange,
)
return invocation
}

iterator, err := v.array.Iterator()
if err != nil {
panic(errors.NewExternalError(err))
}

return NewArrayValueWithIterator(
interpreter,
NewVariableSizedStaticType(interpreter, v.Type.ElementType()),
common.ZeroAddress,
uint64(v.Count()), // worst case estimation.
func() Value {

var value Value

for {
atreeValue, err := iterator.Next()
if err != nil {
panic(errors.NewExternalError(err))
}

// Also handles the end of array case since iterator.Next() returns nil for that.
if atreeValue == nil {
return nil
}

value = MustConvertStoredValue(interpreter, atreeValue)
if value == nil {
return nil
}

shouldInclude, ok := procedure.invoke(iterationInvocation(value)).(BoolValue)
if !ok {
panic(errors.NewUnreachableError())
}

// We found the next entry of the filtered array.
if shouldInclude {
break
}
}

return value.Transfer(
interpreter,
locationRange,
atree.Address{},
false,
nil,
nil,
)
},
)
}

// NumberValue
type NumberValue interface {
ComparableValue
Expand Down
57 changes: 57 additions & 0 deletions runtime/sema/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -1795,6 +1795,13 @@ Returns a new array with contents in the reversed order.
Available if the array element type is not resource-kinded.
`

const ArrayTypeFilterFunctionName = "filter"

const arrayTypeFilterFunctionDocString = `
Returns a new array whose elements are filtered by applying the filter function on each element of the original array.
Available if the array element type is not resource-kinded.
`

func getArrayMembers(arrayType ArrayType) map[string]MemberResolver {

members := map[string]MemberResolver{
Expand Down Expand Up @@ -1913,6 +1920,31 @@ func getArrayMembers(arrayType ArrayType) map[string]MemberResolver {
)
},
},
ArrayTypeFilterFunctionName: {
Kind: common.DeclarationKindFunction,
Resolve: func(memoryGauge common.MemoryGauge, identifier string, targetRange ast.Range, report func(error)) *Member {

elementType := arrayType.ElementType(false)

if elementType.IsResourceType() {
report(
&InvalidResourceArrayMemberError{
Name: identifier,
DeclarationKind: common.DeclarationKindFunction,
Range: targetRange,
},
)
}

return NewPublicFunctionMember(
memoryGauge,
arrayType,
identifier,
ArrayFilterFunctionType(memoryGauge, elementType),
arrayTypeFilterFunctionDocString,
)
},
},
}

// TODO: maybe still return members but report a helpful error?
Expand Down Expand Up @@ -2232,6 +2264,31 @@ func ArrayReverseFunctionType(arrayType ArrayType) *FunctionType {
}
}

func ArrayFilterFunctionType(memoryGauge common.MemoryGauge, elementType Type) *FunctionType {
// fun filter(_ function: ((T): Bool)): [T]
// funcType: elementType -> Bool
funcType := &FunctionType{
Parameters: []Parameter{
{
Identifier: "element",
TypeAnnotation: NewTypeAnnotation(elementType),
},
},
ReturnTypeAnnotation: NewTypeAnnotation(BoolType),
}

return &FunctionType{
Parameters: []Parameter{
{
Label: ArgumentLabelNotRequired,
Identifier: "f",
TypeAnnotation: NewTypeAnnotation(funcType),
},
},
ReturnTypeAnnotation: NewTypeAnnotation(NewVariableSizedType(memoryGauge, elementType)),
}
}

// VariableSizedType is a variable sized array type
type VariableSizedType struct {
Type Type
Expand Down
97 changes: 97 additions & 0 deletions runtime/tests/checker/arrays_dictionaries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,103 @@ func TestCheckResourceArrayReverseInvalid(t *testing.T) {
assert.IsType(t, &sema.InvalidResourceArrayMemberError{}, errs[0])
}

func TestCheckArrayFilter(t *testing.T) {

t.Parallel()

_, err := ParseAndCheck(t, `
fun test() {
let x = [1, 2, 3]
let onlyEven =
fun (_ x: Int): Bool {
return x % 2 == 0
}
let y = x.filter(onlyEven)
}
fun testFixedSize() {
let x : [Int; 5] = [1, 2, 3, 21, 30]
let onlyEvenInt =
fun (_ x: Int): Bool {
return x % 2 == 0
}
let y = x.filter(onlyEvenInt)
}
`)

require.NoError(t, err)
}

func TestCheckArrayFilterInvalidArgs(t *testing.T) {

t.Parallel()

testInvalidArgs := func(code string, expectedErrors []sema.SemanticError) {
_, err := ParseAndCheck(t, code)

errs := RequireCheckerErrors(t, err, len(expectedErrors))

for i, e := range expectedErrors {
assert.IsType(t, e, errs[i])
}
}

testInvalidArgs(`
fun test() {
let x = [1, 2, 3]
let y = x.filter(100)
}
`,
[]sema.SemanticError{
&sema.TypeMismatchError{},
},
)

testInvalidArgs(`
fun test() {
let x = [1, 2, 3]
let onlyEvenInt16 =
fun (_ x: Int16): Bool {
return x % 2 == 0
}
let y = x.filter(onlyEvenInt16)
}
`,
[]sema.SemanticError{
&sema.TypeMismatchError{},
},
)
}

func TestCheckResourceArrayFilterInvalid(t *testing.T) {

t.Parallel()

_, err := ParseAndCheck(t, `
resource X {}
fun test(): @[X] {
let xs <- [<-create X()]
let allResources =
fun (_ x: @X): Bool {
destroy x
return true
}
let filteredXs <-xs.filter(allResources)
destroy xs
return <- filteredXs
}
`)

errs := RequireCheckerErrors(t, err, 1)

assert.IsType(t, &sema.InvalidResourceArrayMemberError{}, errs[0])
}

func TestCheckArrayContains(t *testing.T) {

t.Parallel()
Expand Down
Loading

0 comments on commit af119ed

Please sign in to comment.