Skip to content

Commit

Permalink
Extend mapping of nullable types for pandas (#278)
Browse files Browse the repository at this point in the history
* Extend mapping of nullable types for pandas

* Add test for `nullable_series` expected output types
  • Loading branch information
oliverholworthy authored Apr 12, 2023
1 parent 4378e9f commit dd98a43
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
8 changes: 8 additions & 0 deletions merlin/core/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,16 @@ def inner2(*args, **kwargs):
# Define mapping between non-nullable,
# and nullable types in Pandas
_PD_NULLABLE_MAP = {
"float32": "Float32",
"float64": "Float64",
"int8": "Int8",
"int16": "Int16",
"int32": "Int32",
"int64": "Int64",
"uint8": "UInt8",
"uint16": "UInt16",
"uint32": "UInt32",
"uint64": "UInt64",
}


Expand Down
28 changes: 27 additions & 1 deletion tests/unit/core/test_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@

from merlin.core.compat import HAS_GPU
from merlin.core.compat import cupy as cp
from merlin.core.dispatch import concat_columns, is_list_dtype, list_val_dtype, make_df
from merlin.core.dispatch import (
concat_columns,
is_list_dtype,
list_val_dtype,
make_df,
nullable_series,
)

if HAS_GPU:
_DEVICES = ["cpu", "gpu"]
Expand Down Expand Up @@ -59,3 +65,23 @@ def test_pandas_cupy_combo():
pd_df = pd.DataFrame(rand_cp_nd_arr.get())[0]
mk_df = make_df(rand_cp_nd_arr)[0]
assert all(pd_df.to_numpy() == mk_df.to_numpy())


@pytest.mark.parametrize(
["data", "dtype", "expected_series"],
[
[[None], np.dtype("int8"), pd.Series([pd.NA], dtype="Int8")],
[[None], np.dtype("int16"), pd.Series([pd.NA], dtype="Int16")],
[[None], np.dtype("int32"), pd.Series([pd.NA], dtype="Int32")],
[[None], np.dtype("int64"), pd.Series([pd.NA], dtype="Int64")],
[[None], np.dtype("uint8"), pd.Series([pd.NA], dtype="UInt8")],
[[None], np.dtype("uint16"), pd.Series([pd.NA], dtype="UInt16")],
[[None], np.dtype("uint32"), pd.Series([pd.NA], dtype="UInt32")],
[[None], np.dtype("uint64"), pd.Series([pd.NA], dtype="UInt64")],
[[None], np.dtype("float32"), pd.Series([pd.NA], dtype="Float32")],
[[None], np.dtype("float64"), pd.Series([pd.NA], dtype="Float64")],
],
)
def test_nullable_series(data, dtype, expected_series):
series = nullable_series(data, pd.DataFrame(), dtype)
pd.testing.assert_series_equal(series, expected_series)

0 comments on commit dd98a43

Please sign in to comment.