diff --git a/petastorm/arrow_reader_worker.py b/petastorm/arrow_reader_worker.py
index 6f3ae3e3..2b602d99 100644
--- a/petastorm/arrow_reader_worker.py
+++ b/petastorm/arrow_reader_worker.py
@@ -62,7 +62,9 @@ def read_next(self, workers_pool, schema, ngram):
                 column_as_numpy = column_as_pandas
 
             if pa.types.is_string(column.type):
-                result_dict[column_name] = column_as_numpy.astype(np.unicode_)
+                # Keep string columns as Python objects so that null entries can be returned as None
+                # (see test_entire_column_of_typed_nulls). np.object_ replaces the removed np.object alias.
+                result_dict[column_name] = column_as_numpy.astype(np.object_)
             elif pa.types.is_list(column.type) or pa.types.is_fixed_size_list(column.type):
                 # Assuming all lists are of the same length, hence we can collate them into a matrix
                 list_of_lists = column_as_numpy
diff --git a/petastorm/tests/test_parquet_reader.py b/petastorm/tests/test_parquet_reader.py
index a9afd7ef..203a8316 100644
--- a/petastorm/tests/test_parquet_reader.py
+++ b/petastorm/tests/test_parquet_reader.py
@@ -16,6 +16,7 @@
 
 import numpy as np
 import pandas as pd
+import pyarrow as pa
 import pytest
 from pyarrow import parquet as pq
 
@@ -223,6 +224,40 @@ def fill_id_with_nones(x):
         assert sample.id.dtype.type == null_column_dtype
 
 
+@pytest.mark.parametrize('np_dtype, pa_dtype, null_value',
+                         ((np.float32, pa.float32(), np.nan), (np.object_, pa.string(), None)))
+@pytest.mark.parametrize('reader_factory', _D)
+def test_entire_column_of_typed_nulls(reader_factory, np_dtype, pa_dtype, null_value, tmp_path):
+    """A column made only of typed nulls is read back with the expected numpy dtype and null values."""
+    path = tmp_path / "dataset"
+    schema = pa.schema([pa.field('all_nulls', pa_dtype)])
+    pq.write_table(pa.Table.from_pydict({"all_nulls": [null_value] * 10}, schema=schema), path)
+
+    with reader_factory("file:///" + str(path)) as reader:
+        sample = next(reader)
+        assert sample.all_nulls.dtype == np_dtype
+        if np_dtype == np.float32:
+            assert np.all(np.isnan(sample.all_nulls))
+        elif np_dtype == np.object_:
+            assert all(v is None for v in sample.all_nulls)
+        else:
+            assert False, "Unexpected np_dtype"
+
+
+@pytest.mark.parametrize('reader_factory', _D)
+def test_column_with_list_of_strings_some_are_null(reader_factory, tmp_path):
+    """Null entries inside a list-of-strings column are read back as None within an object array."""
+    path = tmp_path / "dataset"
+    schema = pa.schema([pa.field('some_nulls', pa.list_(pa.string(), -1))])
+    pq.write_table(pa.Table.from_pydict({"some_nulls": [['a0', 'a1'], ['b0', None], [None, None]]}, schema=schema),
+                   path)
+
+    with reader_factory("file:///" + str(path)) as reader:
+        sample = next(reader)
+        assert sample.some_nulls.dtype == np.object_
+        np.testing.assert_equal(sample.some_nulls, [['a0', 'a1'], ['b0', None], [None, None]])
+
+
 @pytest.mark.parametrize('reader_factory', _D)
 def test_transform_spec_returns_all_none_values_in_a_list_field(scalar_dataset, reader_factory):
     def fill_id_with_nones(x):