Skip to content

Commit b9bdca8

Browse files
authored
fix: dtype parameter ineffective in Series/DataFrame construction (#1354)
* fix: dtype parameter ineffective in Series IO
* Revert "docs: update struct examples. (#953)" — this reverts commit d632cd0.
* skip array tests because of dtype mismatches
1 parent 7ae565d commit b9bdca8

File tree

5 files changed

+116
-25
lines changed

5 files changed

+116
-25
lines changed

bigframes/dtypes.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -295,7 +295,10 @@ def is_object_like(type_: Union[ExpressionType, str]) -> bool:
295295
# See: https://stackoverflow.com/a/40312924/101923 and
296296
# https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html
297297
# for the way to identify object type.
298-
return type_ in ("object", "O") or getattr(type_, "kind", None) == "O"
298+
return type_ in ("object", "O") or (
299+
getattr(type_, "kind", None) == "O"
300+
and getattr(type_, "storage", None) != "pyarrow"
301+
)
299302

300303

301304
def is_string_like(type_: ExpressionType) -> bool:

tests/system/small/bigquery/test_json.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616

1717
import geopandas as gpd # type: ignore
1818
import pandas as pd
19+
import pyarrow as pa
1920
import pytest
2021

2122
import bigframes.bigquery as bbq
@@ -174,7 +175,7 @@ def test_json_extract_array_from_json_strings():
174175
actual = bbq.json_extract_array(s, "$.a")
175176
expected = bpd.Series(
176177
[['"ab"', '"2"', '"3 xy"'], [], ['"4"', '"5"'], None],
177-
dtype=pd.StringDtype(storage="pyarrow"),
178+
dtype=pd.ArrowDtype(pa.list_(pa.string())),
178179
)
179180
pd.testing.assert_series_equal(
180181
actual.to_pandas(),
@@ -190,7 +191,7 @@ def test_json_extract_array_from_json_array_strings():
190191
actual = bbq.json_extract_array(s)
191192
expected = bpd.Series(
192193
[["1", "2", "3"], [], ["4", "5"]],
193-
dtype=pd.StringDtype(storage="pyarrow"),
194+
dtype=pd.ArrowDtype(pa.list_(pa.string())),
194195
)
195196
pd.testing.assert_series_equal(
196197
actual.to_pandas(),

tests/system/small/test_dataframe.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -166,6 +166,19 @@ def test_df_construct_inline_respects_location():
166166
assert table.location == "europe-west1"
167167

168168

169+
def test_df_construct_dtype():
170+
data = {
171+
"int_col": [1, 2, 3],
172+
"string_col": ["1.1", "2.0", "3.5"],
173+
"float_col": [1.0, 2.0, 3.0],
174+
}
175+
dtype = pd.StringDtype(storage="pyarrow")
176+
bf_result = dataframe.DataFrame(data, dtype=dtype)
177+
pd_result = pd.DataFrame(data, dtype=dtype)
178+
pd_result.index = pd_result.index.astype("Int64")
179+
pandas.testing.assert_frame_equal(bf_result.to_pandas(), pd_result)
180+
181+
169182
def test_get_column(scalars_dfs):
170183
scalars_df, scalars_pandas_df = scalars_dfs
171184
col_name = "int64_col"

tests/system/small/test_series.py

Lines changed: 74 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
import pytest
2727
import shapely # type: ignore
2828

29+
import bigframes.features
2930
import bigframes.pandas
3031
import bigframes.series as series
3132
from tests.system.utils import (
@@ -228,6 +229,79 @@ def test_series_construct_geodata():
228229
)
229230

230231

232+
@pytest.mark.parametrize(
233+
("dtype"),
234+
[
235+
pytest.param(pd.Int64Dtype(), id="int"),
236+
pytest.param(pd.Float64Dtype(), id="float"),
237+
pytest.param(pd.StringDtype(storage="pyarrow"), id="string"),
238+
],
239+
)
240+
def test_series_construct_w_dtype_for_int(dtype):
241+
data = [1, 2, 3]
242+
expected = pd.Series(data, dtype=dtype)
243+
expected.index = expected.index.astype("Int64")
244+
series = bigframes.pandas.Series(data, dtype=dtype)
245+
pd.testing.assert_series_equal(series.to_pandas(), expected)
246+
247+
248+
def test_series_construct_w_dtype_for_struct():
249+
# The data shows the struct fields are disordered and correctly handled during
250+
# construction.
251+
data = [
252+
{"a": 1, "c": "pandas", "b": dt.datetime(2020, 1, 20, 20, 20, 20, 20)},
253+
{"a": 2, "c": "pandas", "b": dt.datetime(2019, 1, 20, 20, 20, 20, 20)},
254+
{"a": 1, "c": "numpy", "b": None},
255+
]
256+
dtype = pd.ArrowDtype(
257+
pa.struct([("a", pa.int64()), ("c", pa.string()), ("b", pa.timestamp("us"))])
258+
)
259+
series = bigframes.pandas.Series(data, dtype=dtype)
260+
expected = pd.Series(data, dtype=dtype)
261+
expected.index = expected.index.astype("Int64")
262+
pd.testing.assert_series_equal(series.to_pandas(), expected)
263+
264+
265+
def test_series_construct_w_dtype_for_array_string():
266+
data = [["1", "2", "3"], [], ["4", "5"]]
267+
dtype = pd.ArrowDtype(pa.list_(pa.string()))
268+
series = bigframes.pandas.Series(data, dtype=dtype)
269+
expected = pd.Series(data, dtype=dtype)
270+
expected.index = expected.index.astype("Int64")
271+
272+
# Skip dtype check due to internal issue b/321013333. This issue causes array types
273+
# to be converted to the `object` dtype when calling `to_pandas()`, resulting in
274+
# a mismatch with the expected Pandas type.
275+
if bigframes.features.PANDAS_VERSIONS.is_arrow_list_dtype_usable:
276+
check_dtype = True
277+
else:
278+
check_dtype = False
279+
280+
pd.testing.assert_series_equal(
281+
series.to_pandas(), expected, check_dtype=check_dtype
282+
)
283+
284+
285+
def test_series_construct_w_dtype_for_array_struct():
286+
data = [[{"a": 1, "c": "aa"}, {"a": 2, "c": "bb"}], [], [{"a": 3, "c": "cc"}]]
287+
dtype = pd.ArrowDtype(pa.list_(pa.struct([("a", pa.int64()), ("c", pa.string())])))
288+
series = bigframes.pandas.Series(data, dtype=dtype)
289+
expected = pd.Series(data, dtype=dtype)
290+
expected.index = expected.index.astype("Int64")
291+
292+
# Skip dtype check due to internal issue b/321013333. This issue causes array types
293+
# to be converted to the `object` dtype when calling `to_pandas()`, resulting in
294+
# a mismatch with the expected Pandas type.
295+
if bigframes.features.PANDAS_VERSIONS.is_arrow_list_dtype_usable:
296+
check_dtype = True
297+
else:
298+
check_dtype = False
299+
300+
pd.testing.assert_series_equal(
301+
series.to_pandas(), expected, check_dtype=check_dtype
302+
)
303+
304+
231305
def test_series_keys(scalars_dfs):
232306
scalars_df, scalars_pandas_df = scalars_dfs
233307
bf_result = scalars_df["int64_col"].keys().to_pandas()

third_party/bigframes_vendored/pandas/core/arrays/arrow/accessors.py

Lines changed: 22 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -87,12 +87,12 @@ def field(self, name_or_index: str | int):
8787
>>> bpd.options.display.progress_bar = None
8888
>>> s = bpd.Series(
8989
... [
90-
... {"project": "pandas", "version": 1},
91-
... {"project": "pandas", "version": 2},
92-
... {"project": "numpy", "version": 1},
90+
... {"version": 1, "project": "pandas"},
91+
... {"version": 2, "project": "pandas"},
92+
... {"version": 1, "project": "numpy"},
9393
... ],
9494
... dtype=bpd.ArrowDtype(pa.struct(
95-
... [("project", pa.string()), ("version", pa.int64())]
95+
... [("version", pa.int64()), ("project", pa.string())]
9696
... ))
9797
... )
9898
@@ -106,7 +106,7 @@ def field(self, name_or_index: str | int):
106106
107107
Extract by field index.
108108
109-
>>> s.struct.field(1)
109+
>>> s.struct.field(0)
110110
0 1
111111
1 2
112112
2 1
@@ -133,22 +133,22 @@ def explode(self):
133133
>>> bpd.options.display.progress_bar = None
134134
>>> s = bpd.Series(
135135
... [
136-
... {"project": "pandas", "version": 1},
137-
... {"project": "pandas", "version": 2},
138-
... {"project": "numpy", "version": 1},
136+
... {"version": 1, "project": "pandas"},
137+
... {"version": 2, "project": "pandas"},
138+
... {"version": 1, "project": "numpy"},
139139
... ],
140140
... dtype=bpd.ArrowDtype(pa.struct(
141-
... [("project", pa.string()), ("version", pa.int64())]
141+
... [("version", pa.int64()), ("project", pa.string())]
142142
... ))
143143
... )
144144
145145
Extract all child fields.
146146
147147
>>> s.struct.explode()
148-
project version
149-
0 pandas 1
150-
1 pandas 2
151-
2 numpy 1
148+
version project
149+
0 1 pandas
150+
1 2 pandas
151+
2 1 numpy
152152
<BLANKLINE>
153153
[3 rows x 2 columns]
154154
@@ -178,8 +178,8 @@ def dtypes(self):
178178
... ))
179179
... )
180180
>>> s.struct.dtypes()
181-
project string[pyarrow]
182181
version Int64
182+
project string[pyarrow]
183183
dtype: object
184184
185185
Returns:
@@ -205,21 +205,21 @@ def explode(self, column, *, separator: str = "."):
205205
>>> countries = bpd.Series(["cn", "es", "us"])
206206
>>> files = bpd.Series(
207207
... [
208-
... {"project": "pandas", "version": 1},
209-
... {"project": "pandas", "version": 2},
210-
... {"project": "numpy", "version": 1},
208+
... {"version": 1, "project": "pandas"},
209+
... {"version": 2, "project": "pandas"},
210+
... {"version": 1, "project": "numpy"},
211211
... ],
212212
... dtype=bpd.ArrowDtype(pa.struct(
213-
... [("project", pa.string()), ("version", pa.int64())]
213+
... [("version", pa.int64()), ("project", pa.string())]
214214
... ))
215215
... )
216216
>>> downloads = bpd.Series([100, 200, 300])
217217
>>> df = bpd.DataFrame({"country": countries, "file": files, "download_count": downloads})
218218
>>> df.struct.explode("file")
219-
country file.project file.version download_count
220-
0 cn pandas 1 100
221-
1 es pandas 2 200
222-
2 us numpy 1 300
219+
country file.version file.project download_count
220+
0 cn 1 pandas 100
221+
1 es 2 pandas 200
222+
2 us 1 numpy 300
223223
<BLANKLINE>
224224
[3 rows x 4 columns]
225225

0 commit comments

Comments (0)