Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Add String __add__ with StringSlice and StringLiteral & optimize to not use List.resize() #3591

Closed
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ repos:
stages: [commit]
- id: check-license
name: check-license
entry: mojo stdlib/scripts/check-licenses.mojo
entry: mojo stdlib/scripts/check_licenses.mojo
language: system
files: '\.(mojo|🔥|py)$'
stages: [commit]
Expand Down
155 changes: 118 additions & 37 deletions stdlib/src/collections/string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,30 @@ struct String(
"""
return not (self < rhs)

@staticmethod
fn _add[rhs_has_null: Bool](lhs: Span[UInt8], rhs: Span[UInt8]) -> String:
var lhs_len = len(lhs)
var rhs_len = len(rhs)
var lhs_ptr = lhs.unsafe_ptr()
var rhs_ptr = rhs.unsafe_ptr()
alias S = StringSlice[ImmutableAnyOrigin]
if lhs_len == 0:
return String(S(unsafe_from_utf8_ptr=rhs_ptr, len=rhs_len))
elif rhs_len == 0:
return String(S(unsafe_from_utf8_ptr=lhs_ptr, len=lhs_len))
var sum_len = lhs_len + rhs_len
var buffer = Self._buffer_type(capacity=sum_len + 1)
var ptr = buffer.unsafe_ptr()
memcpy(ptr, lhs_ptr, lhs_len)
memcpy(ptr + lhs_len, rhs_ptr, rhs_len + int(rhs_has_null))
buffer.size = sum_len + 1

@parameter
if not rhs_has_null:
ptr[sum_len] = 0
return Self(buffer^)

@always_inline
fn __add__(self, other: String) -> String:
"""Creates a string by appending another string at the end.

Expand All @@ -1072,26 +1096,31 @@ struct String(
Returns:
The new constructed string.
"""
if not self:
return other
if not other:
return self
var self_len = self.byte_length()
var other_len = other.byte_length()
var total_len = self_len + other_len
var buffer = Self._buffer_type()
buffer.resize(total_len + 1, 0)
memcpy(
buffer.data,
self.unsafe_ptr(),
self_len,
)
memcpy(
buffer.data + self_len,
other.unsafe_ptr(),
other_len + 1, # Also copy the terminator
)
return Self(buffer^)
return Self._add[True](self.as_bytes(), other.as_bytes())

@always_inline
fn __add__(self, other: StringLiteral) -> String:
"""Creates a string by appending a string literal at the end.

Args:
other: The string literal to append.

Returns:
The new constructed string.
"""
return Self._add[False](self.as_bytes(), other.as_bytes())

@always_inline
fn __add__(self, other: StringSlice) -> String:
"""Creates a string by appending a string slice at the end.

Args:
other: The string slice to append.

Returns:
The new constructed string.
"""
return Self._add[False](self.as_bytes(), other.as_bytes())

@always_inline
fn __radd__(self, other: String) -> String:
Expand All @@ -1103,30 +1132,80 @@ struct String(
Returns:
The new constructed string.
"""
return other + self
return Self._add[True](other.as_bytes(), self.as_bytes())

@always_inline
fn __radd__(self, other: StringLiteral) -> String:
"""Creates a string by prepending another string literal to the start.

Args:
other: The string to prepend.

Returns:
The new constructed string.
"""
return Self._add[True](other.as_bytes(), self.as_bytes())

@always_inline
fn __radd__(self, other: StringSlice) -> String:
"""Creates a string by prepending another string slice to the start.

Args:
other: The string to prepend.

Returns:
The new constructed string.
"""
return Self._add[True](other.as_bytes(), self.as_bytes())

fn _iadd[has_null: Bool](inout self, other: Span[UInt8]):
var s_len = self.byte_length()
var o_len = len(other)
var o_ptr = other.unsafe_ptr()
if s_len == 0:
alias S = StringSlice[ImmutableAnyOrigin]
self = String(S(unsafe_from_utf8_ptr=o_ptr, len=o_len))
return
elif o_len == 0:
return
var sum_len = s_len + o_len
self._buffer.reserve(sum_len + 1)
var s_ptr = self.unsafe_ptr()
memcpy(s_ptr + s_len, o_ptr, o_len + int(has_null))
self._buffer.size = sum_len + 1

@parameter
if not has_null:
s_ptr[sum_len] = 0

@always_inline
fn __iadd__(inout self, other: String):
"""Appends another string to this string.

Args:
other: The string to append.
"""
if not self:
self = other
return
if not other:
return
var self_len = self.byte_length()
var other_len = other.byte_length()
var total_len = self_len + other_len
self._buffer.resize(total_len + 1, 0)
# Copy the data alongside the terminator.
memcpy(
dest=self.unsafe_ptr() + self_len,
src=other.unsafe_ptr(),
count=other_len + 1,
)
self._iadd[True](other.as_bytes())

@always_inline
fn __iadd__(inout self, other: StringLiteral):
"""Appends another string literal to this string.

Args:
other: The string to append.
"""
self._iadd[False](other.as_bytes())

@always_inline
fn __iadd__(inout self, other: StringSlice):
"""Appends another string slice to this string.

Args:
other: The string to append.
"""
self._iadd[False](other.as_bytes())

@always_inline
fn __iter__(ref [_]self) -> _StringSliceIter[__origin_of(self)]:
"""Iterate over elements of the string, returning immutable references.

Expand All @@ -1137,6 +1216,7 @@ struct String(
unsafe_pointer=self.unsafe_ptr(), length=self.byte_length()
)

@always_inline
fn __reversed__(
ref [_]self,
) -> _StringSliceIter[__origin_of(self), False]:
Expand Down Expand Up @@ -1473,7 +1553,8 @@ struct String(
Notes:
This does not include the trailing null terminator in the count.
"""
return max(len(self._buffer) - 1, 0)
var length = len(self._buffer)
return length - int(length > 0)

fn _steal_ptr(inout self) -> UnsafePointer[UInt8]:
"""Transfer ownership of pointer to the underlying memory.
Expand Down
16 changes: 15 additions & 1 deletion stdlib/test/collections/test_string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ from testing import (
assert_true,
)

from utils import StringRef
from utils import StringRef, StringSlice


@value
Expand Down Expand Up @@ -193,6 +193,19 @@ def test_add():
assert_equal("abc is a string", str(s8) + str(s9))


def test_add_string_slice():
var s1 = String("123")
var s2 = StringSlice("abc")
var s3: StringLiteral = "abc"
assert_equal("123abc", s1 + s2)
assert_equal("123abc", s1 + s3)
assert_equal("abc123", s2 + s1)
assert_equal("abc123", s3 + s1)
s1 += s2
s1 += s3
assert_equal("123abcabc", s1)


def test_string_join():
var sep = String(",")
var s0 = String("abc")
Expand Down Expand Up @@ -1623,6 +1636,7 @@ def main():
test_equality_operators()
test_comparison_operators()
test_add()
test_add_string_slice()
test_stringable()
test_repr()
test_string_join()
Expand Down