Skip to content

Commit

Permalink
Update instance checks in concat_columns to use cudf/pd classes.
Browse files Browse the repository at this point in the history
Adds test for concat_columns with list columns
  • Loading branch information
oliverholworthy committed Nov 10, 2022
1 parent c1ddc19 commit 6eeceb9
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 15 deletions.
13 changes: 4 additions & 9 deletions merlin/core/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
import pyarrow.parquet as pq

from merlin.core.compat import HAS_GPU
from merlin.core.protocols import DataFrameLike, DictLike, SeriesLike
from merlin.core.protocols import DataFrameLike, SeriesLike


cp = None
cudf = None
Expand Down Expand Up @@ -350,23 +351,17 @@ def concat_columns(args: list):
"""Dispatch function to concatenate DataFrames with axis=1"""
if len(args) == 1:
return args[0]
elif isinstance(args[0], DataFrameLike):
elif isinstance(args[0], (cudf.DataFrame, pd.DataFrame)):
_lib = cudf if HAS_GPU and isinstance(args[0], cudf.DataFrame) else pd
return _lib.concat(
[a.reset_index(drop=True) for a in args],
axis=1,
)
elif isinstance(args[0], DictLike):
else:
result = type(args[0])()
for arg in args:
result.update(arg)
return result
else:
_lib = cudf if HAS_GPU and isinstance(args[0], cudf.DataFrame) else pd
return _lib.concat(
[a.reset_index(drop=True) for a in args],
axis=1,
)
return None


Expand Down
21 changes: 15 additions & 6 deletions tests/unit/core/test_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@
import numpy as np
import pytest

from merlin.core.dispatch import HAS_GPU, is_list_dtype, list_val_dtype, make_df
from merlin.core.dispatch import HAS_GPU, is_list_dtype, list_val_dtype, make_df, concat_columns

if HAS_GPU:
_CPU = [True, False]
_DEVICES = ["cpu", "gpu"]
else:
_CPU = [True]
_DEVICES = ["cpu"]


@pytest.mark.parametrize("cpu", _CPU)
def test_list_dtypes(tmpdir, cpu):
df = make_df(device="cpu" if cpu else "gpu")
@pytest.mark.parametrize("device", _DEVICES)
def test_list_dtypes(tmpdir, device):
df = make_df(device=device)
df["vals"] = [
[[0, 1, 2], [3, 4], [5]],
]
Expand All @@ -35,3 +35,12 @@ def test_list_dtypes(tmpdir, cpu):

assert is_list_dtype(df["vals"])
assert list_val_dtype(df["vals"]) == np.dtype(np.int64)


@pytest.mark.parametrize("device", _DEVICES)
def test_concat_columns(device):
df1 = make_df({"a": [1, 2], "b": [[3], [4, 5]]}, device=device)
df2 = make_df({"c": [3, 4, 5]}, device=device)
data_frames = [df1, df2]
res = concat_columns(data_frames)
assert res.columns.to_list() == ["a", "b", "c"]

0 comments on commit 6eeceb9

Please sign in to comment.