diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index b80d96da8e..e1a60855df 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -1043,6 +1043,8 @@ var _ ValueIndexableValue = &StringValue{} var _ MemberAccessibleValue = &StringValue{} var _ IterableValue = &StringValue{} +var VarSizedArrayOfStringType = NewVariableSizedStaticType(nil, PrimitiveStaticTypeString) + func (v *StringValue) prepareGraphemes() { if v.graphemes == nil { v.graphemes = uniseg.NewGraphemes(v.Str) @@ -1342,6 +1344,20 @@ func (v *StringValue) GetMember(interpreter *Interpreter, locationRange Location return v.ToLower(invocation.Interpreter) }, ) + + case sema.StringTypeSplitFunctionName: + return NewHostFunctionValue( + interpreter, + sema.StringTypeSplitFunctionType, + func(invocation Invocation) Value { + separator, ok := invocation.Arguments[0].(*StringValue) + if !ok { + panic(errors.NewUnreachableError()) + } + + return v.Split(invocation.Interpreter, invocation.LocationRange, separator.Str) + }, + ) } return nil @@ -1396,6 +1412,35 @@ func (v *StringValue) ToLower(interpreter *Interpreter) *StringValue { ) } +func (v *StringValue) Split(inter *Interpreter, locationRange LocationRange, separator string) Value { + split := strings.Split(v.Str, separator) + + var index int + count := len(split) + + return NewArrayValueWithIterator( + inter, + VarSizedArrayOfStringType, + common.ZeroAddress, + uint64(count), + func() Value { + if index >= count { + return nil + } + + str := split[index] + index++ + return NewStringValue( + inter, + common.NewStringMemoryUsage(len(str)), + func() string { + return str + }, + ) + }, + ) +} + func (v *StringValue) Storable(storage atree.SlabStorage, address atree.Address, maxInlineSize uint64) (atree.Storable, error) { return maybeLargeImmutableStorable(v, storage, address, maxInlineSize) } diff --git a/runtime/sema/string_type.go b/runtime/sema/string_type.go index b4fa762599..691f14db13 100644 --- a/runtime/sema/string_type.go +++ b/runtime/sema/string_type.go @@ -42,6 +42,11 @@ const StringTypeJoinFunctionDocString = ` Returns a string after joining the array of strings with the provided separator. ` +const StringTypeSplitFunctionName = "split" +const StringTypeSplitFunctionDocString = ` +Returns a variable-sized array of strings after splitting the string on the delimiter. +` + // StringType represents the string type var StringType = &SimpleType{ Name: "String", @@ -105,6 +110,12 @@ func init() { StringTypeToLowerFunctionType, stringTypeToLowerFunctionDocString, ), + NewUnmeteredPublicFunctionMember( + t, + StringTypeSplitFunctionName, + StringTypeSplitFunctionType, + StringTypeSplitFunctionDocString, + ), }) } } @@ -335,3 +346,18 @@ var StringTypeJoinFunctionType = NewSimpleFunctionType( }, StringTypeAnnotation, ) + +var StringTypeSplitFunctionType = NewSimpleFunctionType( + FunctionPurityView, + []Parameter{ + { + Identifier: "separator", + TypeAnnotation: StringTypeAnnotation, + }, + }, + NewTypeAnnotation( + &VariableSizedType{ + Type: StringType, + }, + ), +) diff --git a/runtime/tests/checker/string_test.go b/runtime/tests/checker/string_test.go index 2ce9fc681b..5d94c0c077 100644 --- a/runtime/tests/checker/string_test.go +++ b/runtime/tests/checker/string_test.go @@ -408,3 +408,46 @@ func TestCheckStringJoinTypeMissingArgumentLabelSeparator(t *testing.T) { assert.IsType(t, &sema.MissingArgumentLabelError{}, errs[0]) } + +func TestCheckStringSplit(t *testing.T) { + + t.Parallel() + + checker, err := ParseAndCheck(t, ` + let s = "👪.❤️.Abc".split(separator: ".") + `) + require.NoError(t, err) + + assert.Equal(t, + &sema.VariableSizedType{ + Type: sema.StringType, + }, + RequireGlobalValue(t, checker.Elaboration, "s"), + ) +} + +func TestCheckStringSplitTypeMismatchSeparator(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + let s = "Abc:1".split(separator: 1234) + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.TypeMismatchError{}, errs[0]) +} + +func TestCheckStringSplitTypeMissingArgumentLabelSeparator(t *testing.T) { + + t.Parallel() + + _, err := ParseAndCheck(t, ` + let s = "👪Abc".split("/") + `) + + errs := RequireCheckerErrors(t, err, 1) + + assert.IsType(t, &sema.MissingArgumentLabelError{}, errs[0]) +} diff --git a/runtime/tests/interpreter/string_test.go b/runtime/tests/interpreter/string_test.go index 51873a5a79..37bba83186 100644 --- a/runtime/tests/interpreter/string_test.go +++ b/runtime/tests/interpreter/string_test.go @@ -499,3 +499,100 @@ func TestInterpretStringJoin(t *testing.T) { testCase(t, "testEmptyArray", interpreter.NewUnmeteredStringValue("")) testCase(t, "testSingletonArray", interpreter.NewUnmeteredStringValue("pqrS")) } + +func TestInterpretStringSplit(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + fun split(): [String] { + return "👪////❤️".split(separator: "////") + } + fun splitBySpace(): [String] { + return "👪 ❤️ Abc6 ;123".split(separator: " ") + } + fun splitWithUnicodeEquivalence(): [String] { + return "Caf\u{65}\u{301}ABc".split(separator: "\u{e9}") + } + fun testEmptyString(): [String] { + return "".split(separator: "//") + } + fun testNoMatch(): [String] { + return "pqrS;asdf".split(separator: ";;") + } + `) + + testCase := func(t *testing.T, funcName string, expected *interpreter.ArrayValue) { + t.Run(funcName, func(t *testing.T) { + result, err := inter.Invoke(funcName) + require.NoError(t, err) + + RequireValuesEqual( + t, + inter, + expected, + result, + ) + }) + } + + varSizedStringType := &interpreter.VariableSizedStaticType{ + Type: interpreter.PrimitiveStaticTypeString, + } + + testCase(t, + "split", + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + varSizedStringType, + common.ZeroAddress, + interpreter.NewUnmeteredStringValue("👪"), + interpreter.NewUnmeteredStringValue("❤️"), + ), + ) + testCase(t, + "splitBySpace", + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + varSizedStringType, + common.ZeroAddress, + interpreter.NewUnmeteredStringValue("👪"), + interpreter.NewUnmeteredStringValue("❤️"), + interpreter.NewUnmeteredStringValue("Abc6"), + interpreter.NewUnmeteredStringValue(";123"), + ), + ) + testCase(t, + "splitWithUnicodeEquivalence", + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + varSizedStringType, + common.ZeroAddress, + interpreter.NewUnmeteredStringValue("Caf"), + interpreter.NewUnmeteredStringValue("ABc"), + ), + ) + testCase(t, + "testEmptyString", + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + varSizedStringType, + common.ZeroAddress, + interpreter.NewUnmeteredStringValue(""), + ), + ) + testCase(t, + "testNoMatch", + interpreter.NewArrayValue( + inter, + interpreter.EmptyLocationRange, + varSizedStringType, + common.ZeroAddress, + interpreter.NewUnmeteredStringValue("pqrS;asdf"), + ), + ) +}