# Patch summary (review notes; this text sits ahead of the first "diff --git"
# header and is not part of any hunk).
#
# merlin/core/dispatch.py
#   Adds eight entries to _PD_NULLABLE_MAP so that float32/float64, int8/int16
#   and uint8/uint16/uint32/uint64 map to the capitalised pandas dtype names
#   ("Float32", "Int8", "UInt8", ...), next to the existing int32/int64 entries.
#
# tests/unit/core/test_dispatch.py
#   Imports nullable_series and adds test_nullable_series, parametrized over
#   the ten dtype names now present in the map. Each case calls
#   nullable_series([None], pd.DataFrame(), np.dtype(<name>)) and asserts that
#   the result equals pd.Series([pd.NA], dtype=<mapped name>).
#
# NOTE(review): the body of nullable_series is not part of this patch, so how
# it consults _PD_NULLABLE_MAP is assumed from the test, not shown here --
# confirm.
# NOTE(review): line breaks, indentation and blank context lines were
# reconstructed from the hunk line counts -- confirm against the repository
# before applying.
diff --git a/merlin/core/dispatch.py b/merlin/core/dispatch.py
index ef0811f5a..dd1af7851 100644
--- a/merlin/core/dispatch.py
+++ b/merlin/core/dispatch.py
@@ -84,8 +84,16 @@ def inner2(*args, **kwargs):
 # Define mapping between non-nullable,
 # and nullable types in Pandas
 _PD_NULLABLE_MAP = {
+    "float32": "Float32",
+    "float64": "Float64",
+    "int8": "Int8",
+    "int16": "Int16",
     "int32": "Int32",
     "int64": "Int64",
+    "uint8": "UInt8",
+    "uint16": "UInt16",
+    "uint32": "UInt32",
+    "uint64": "UInt64",
 }
 
 
diff --git a/tests/unit/core/test_dispatch.py b/tests/unit/core/test_dispatch.py
index 2928ed457..45a9d9d23 100644
--- a/tests/unit/core/test_dispatch.py
+++ b/tests/unit/core/test_dispatch.py
@@ -19,7 +19,13 @@
 
 from merlin.core.compat import HAS_GPU
 from merlin.core.compat import cupy as cp
-from merlin.core.dispatch import concat_columns, is_list_dtype, list_val_dtype, make_df
+from merlin.core.dispatch import (
+    concat_columns,
+    is_list_dtype,
+    list_val_dtype,
+    make_df,
+    nullable_series,
+)
 
 if HAS_GPU:
     _DEVICES = ["cpu", "gpu"]
@@ -59,3 +65,23 @@ def test_pandas_cupy_combo():
     pd_df = pd.DataFrame(rand_cp_nd_arr.get())[0]
     mk_df = make_df(rand_cp_nd_arr)[0]
     assert all(pd_df.to_numpy() == mk_df.to_numpy())
+
+
+@pytest.mark.parametrize(
+    ["data", "dtype", "expected_series"],
+    [
+        [[None], np.dtype("int8"), pd.Series([pd.NA], dtype="Int8")],
+        [[None], np.dtype("int16"), pd.Series([pd.NA], dtype="Int16")],
+        [[None], np.dtype("int32"), pd.Series([pd.NA], dtype="Int32")],
+        [[None], np.dtype("int64"), pd.Series([pd.NA], dtype="Int64")],
+        [[None], np.dtype("uint8"), pd.Series([pd.NA], dtype="UInt8")],
+        [[None], np.dtype("uint16"), pd.Series([pd.NA], dtype="UInt16")],
+        [[None], np.dtype("uint32"), pd.Series([pd.NA], dtype="UInt32")],
+        [[None], np.dtype("uint64"), pd.Series([pd.NA], dtype="UInt64")],
+        [[None], np.dtype("float32"), pd.Series([pd.NA], dtype="Float32")],
+        [[None], np.dtype("float64"), pd.Series([pd.NA], dtype="Float64")],
+    ],
+)
+def test_nullable_series(data, dtype, expected_series):
+    series = nullable_series(data, pd.DataFrame(), dtype)
+    pd.testing.assert_series_equal(series, expected_series)