-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
BUG (string): ArrowStringArray.find corner cases #59562
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7a99bdb
f7f19d3
d9f0aa7
f11921e
c34ae46
86ef129
472f17a
e4c782c
e1b7913
8f07638
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -416,18 +416,14 @@ def _str_count(self, pat: str, flags: int = 0): | |
return self._convert_int_result(result) | ||
|
||
def _str_find(self, sub: str, start: int = 0, end: int | None = None): | ||
if start != 0 and end is not None: | ||
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end) | ||
result = pc.find_substring(slices, sub) | ||
not_found = pc.equal(result, -1) | ||
offset_result = pc.add(result, end - start) | ||
result = pc.if_else(not_found, result, offset_result) | ||
elif start == 0 and end is None: | ||
slices = self._pa_array | ||
result = pc.find_substring(slices, sub) | ||
else: | ||
if ( | ||
pa_version_under13p0 | ||
and not (start != 0 and end is not None) | ||
and not (start == 0 and end is None) | ||
): | ||
# GH#59562 | ||
return super()._str_find(sub, start, end) | ||
return self._convert_int_result(result) | ||
return ArrowStringArrayMixin._str_find(self, sub, start, end) | ||
Comment on lines
+419
to
+426
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now that this special case is moved in the mixin method, I would expect this can be removed entirely? (and replaced with a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this goes through a cython path instead of iterating in python There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, through There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, changed. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. looks like doing this broke the min_versions build, so reverted |
||
|
||
def _str_get_dummies(self, sep: str = "|"): | ||
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,8 +32,6 @@ | |
import numpy as np | ||
import pytest | ||
|
||
from pandas._config import using_string_dtype | ||
|
||
from pandas._libs import lib | ||
from pandas._libs.tslibs import timezones | ||
from pandas.compat import ( | ||
|
@@ -1947,14 +1945,9 @@ def test_str_find_negative_start(): | |
|
||
def test_str_find_no_end(): | ||
ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) | ||
if pa_version_under13p0: | ||
# https://github.com/apache/arrow/issues/36311 | ||
with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"): | ||
ser.str.find("ab", start=1) | ||
else: | ||
result = ser.str.find("ab", start=1) | ||
expected = pd.Series([-1, None], dtype="int64[pyarrow]") | ||
tm.assert_series_equal(result, expected) | ||
result = ser.str.find("ab", start=1) | ||
expected = pd.Series([-1, None], dtype="int64[pyarrow]") | ||
tm.assert_series_equal(result, expected) | ||
|
||
|
||
def test_str_find_negative_start_negative_end(): | ||
|
@@ -1968,17 +1961,11 @@ def test_str_find_negative_start_negative_end(): | |
def test_str_find_large_start(): | ||
# GH 56791 | ||
ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string())) | ||
if pa_version_under13p0: | ||
# https://github.com/apache/arrow/issues/36311 | ||
with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"): | ||
ser.str.find(sub="d", start=16) | ||
else: | ||
result = ser.str.find(sub="d", start=16) | ||
expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64())) | ||
tm.assert_series_equal(result, expected) | ||
result = ser.str.find(sub="d", start=16) | ||
expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64())) | ||
tm.assert_series_equal(result, expected) | ||
|
||
|
||
@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False) | ||
@pytest.mark.skipif( | ||
pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311" | ||
) | ||
|
@@ -1990,11 +1977,15 @@ def test_str_find_e2e(start, end, sub): | |
["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""], | ||
dtype=ArrowDtype(pa.string()), | ||
) | ||
object_series = s.astype(pd.StringDtype()) | ||
object_series = s.astype(pd.StringDtype(storage="python")) | ||
result = s.str.find(sub, start, end) | ||
expected = object_series.str.find(sub, start, end).astype(result.dtype) | ||
tm.assert_series_equal(result, expected) | ||
|
||
arrow_str_series = s.astype(pd.StringDtype(storage="pyarrow")) | ||
result2 = arrow_str_series.str.find(sub, start, end).astype(result.dtype) | ||
tm.assert_series_equal(result2, expected) | ||
Comment on lines
+1985
to
+1987
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For future PRs, we should add such tests to |
||
|
||
|
||
def test_str_find_negative_start_negative_end_no_match(): | ||
# GH 56791 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs to be implemented for ArrowStringArray as well then?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is only used for the ArrowEA version. The ArrowStringArray goes through _str_map, which ArrowEA doesn't have. eventually id like to align the names, but there are too many branches/PRs as it is.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure what I am missing, but
_apply_elementwise
is called from the now-shared_str_find
method just below, and so I would think that you can also get there fromArrowStringArray._str_find
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yep my bad. ArrowStringArray inherits ArrowEA so gets its apply_elementwise from there. putting it here just prevents mypy from complaining