Skip to content

Commit 12e04d5

Browse files
authored
fix: Correctly iterate over null struct values in ManagedArrowTable (#2209)
Fixes internal issue 446726636 🦕
1 parent 9b86dcf commit 12e04d5

File tree

4 files changed

+53
-47
lines changed

4 files changed

+53
-47
lines changed

bigframes/core/local_data.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,16 @@ def _(
253253
value_generator = iter_array(
254254
array.flatten(), bigframes.dtypes.get_array_inner_type(dtype)
255255
)
256-
for (start, end) in _pairwise(array.offsets):
257-
arr_size = end.as_py() - start.as_py()
258-
yield list(itertools.islice(value_generator, arr_size))
256+
offset_generator = iter_array(array.offsets, bigframes.dtypes.INT_DTYPE)
257+
258+
start_offset = None
259+
end_offset = None
260+
for offset in offset_generator:
261+
start_offset = end_offset
262+
end_offset = offset
263+
if start_offset is not None:
264+
arr_size = end_offset - start_offset
265+
yield list(itertools.islice(value_generator, arr_size))
259266

260267
@iter_array.register
261268
def _(
@@ -267,8 +274,15 @@ def _(
267274
sub_generators[field_name] = iter_array(array.field(field_name), dtype)
268275

269276
keys = list(sub_generators.keys())
270-
for row_values in zip(*sub_generators.values()):
271-
yield {key: value for key, value in zip(keys, row_values)}
277+
is_null_generator = iter_array(array.is_null(), bigframes.dtypes.BOOL_DTYPE)
278+
279+
for values in zip(is_null_generator, *sub_generators.values()):
280+
is_row_null = values[0]
281+
row_values = values[1:]
282+
if not is_row_null:
283+
yield {key: value for key, value in zip(keys, row_values)}
284+
else:
285+
yield None
272286

273287
for batch in table.to_batches():
274288
sub_generators: dict[str, Generator[Any, None, None]] = {}
@@ -491,16 +505,3 @@ def _schema_durations_to_ints(schema: pa.Schema) -> pa.Schema:
491505
return pa.schema(
492506
pa.field(field.name, _durations_to_ints(field.type)) for field in schema
493507
)
494-
495-
496-
def _pairwise(iterable):
497-
do_yield = False
498-
a = None
499-
b = None
500-
for item in iterable:
501-
a = b
502-
b = item
503-
if do_yield:
504-
yield (a, b)
505-
else:
506-
do_yield = True

tests/system/small/engines/test_read_local.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,9 @@ def test_engines_read_local_w_zero_row_source(
8888
assert_equivalence_execution(local_node, REFERENCE_ENGINE, engine)
8989

9090

91-
# TODO: Fix sqlglot impl
92-
@pytest.mark.parametrize("engine", ["polars", "bq", "pyarrow"], indirect=True)
91+
@pytest.mark.parametrize(
92+
"engine", ["polars", "bq", "pyarrow", "bq-sqlglot"], indirect=True
93+
)
9394
def test_engines_read_local_w_nested_source(
9495
fake_session: bigframes.Session,
9596
nested_data_source: local_data.ManagedArrowTable,

tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_nested_structs_df/out.sql

Lines changed: 0 additions & 19 deletions
This file was deleted.

tests/unit/test_local_data.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,21 @@
2020

2121
pd_data = pd.DataFrame(
2222
{
23-
"ints": [10, 20, 30, 40],
24-
"nested_ints": [[1, 2], [3, 4, 5], [], [20, 30]],
25-
"structs": [{"a": 100}, {}, {"b": 200}, {"b": 300}],
23+
"ints": [10, 20, 30, 40, 50],
24+
"nested_ints": [[1, 2], [], [3, 4, 5], [], [20, 30]],
25+
"structs": [{"a": 100}, None, {}, {"b": 200}, {"b": 300}],
2626
}
2727
)
2828

2929
pd_data_normalized = pd.DataFrame(
3030
{
31-
"ints": pd.Series([10, 20, 30, 40], dtype=dtypes.INT_DTYPE),
31+
"ints": pd.Series([10, 20, 30, 40, 50], dtype=dtypes.INT_DTYPE),
3232
"nested_ints": pd.Series(
33-
[[1, 2], [3, 4, 5], [], [20, 30]], dtype=pd.ArrowDtype(pa.list_(pa.int64()))
33+
[[1, 2], [], [3, 4, 5], [], [20, 30]],
34+
dtype=pd.ArrowDtype(pa.list_(pa.int64())),
3435
),
3536
"structs": pd.Series(
36-
[{"a": 100}, {}, {"b": 200}, {"b": 300}],
37+
[{"a": 100}, None, {}, {"b": 200}, {"b": 300}],
3738
dtype=pd.ArrowDtype(pa.struct({"a": pa.int64(), "b": pa.int64()})),
3839
),
3940
}
@@ -122,11 +123,11 @@ def test_local_data_well_formed_round_trip_chunked():
122123

123124
def test_local_data_well_formed_round_trip_sliced():
124125
pa_table = pa.Table.from_pandas(pd_data, preserve_index=False)
125-
as_rechunked_pyarrow = pa.Table.from_batches(pa_table.slice(2, 4).to_batches())
126+
as_rechunked_pyarrow = pa.Table.from_batches(pa_table.slice(0, 4).to_batches())
126127
local_entry = local_data.ManagedArrowTable.from_pyarrow(as_rechunked_pyarrow)
127128
result = pd.DataFrame(local_entry.itertuples(), columns=pd_data.columns)
128129
pandas.testing.assert_frame_equal(
129-
pd_data_normalized[2:4].reset_index(drop=True),
130+
pd_data_normalized[0:4].reset_index(drop=True),
130131
result.reset_index(drop=True),
131132
check_dtype=False,
132133
)
@@ -143,3 +144,25 @@ def test_local_data_not_equal_other():
143144
local_entry2 = local_data.ManagedArrowTable.from_pandas(pd_data[::2])
144145
assert local_entry != local_entry2
145146
assert hash(local_entry) != hash(local_entry2)
147+
148+
149+
def test_local_data_itertuples_struct_none():
150+
pd_data = pd.DataFrame(
151+
{
152+
"structs": [{"a": 100}, None, {"b": 200}, {"b": 300}],
153+
}
154+
)
155+
local_entry = local_data.ManagedArrowTable.from_pandas(pd_data)
156+
result = list(local_entry.itertuples())
157+
assert result[1][0] is None
158+
159+
160+
def test_local_data_itertuples_list_none():
161+
pd_data = pd.DataFrame(
162+
{
163+
"lists": [[1, 2], None, [3, 4]],
164+
}
165+
)
166+
local_entry = local_data.ManagedArrowTable.from_pandas(pd_data)
167+
result = list(local_entry.itertuples())
168+
assert result[1][0] == []

0 commit comments

Comments
 (0)