Skip to content

Commit

Permalink
Remove explicit DictArray reference from merlin.core.dispatch (#163)
Browse files Browse the repository at this point in the history
* Remove explicit `DictArray` reference from `merlin.core.dispatch`

* Make `DictArray` values arg optional
  • Loading branch information
karlhigley authored Nov 9, 2022
1 parent 64755ba commit c1ddc19
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
13 changes: 9 additions & 4 deletions merlin/core/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
import pyarrow.parquet as pq

from merlin.core.compat import HAS_GPU
from merlin.core.protocols import DataFrameLike, SeriesLike
from merlin.dag import DictArray
from merlin.core.protocols import DataFrameLike, DictLike, SeriesLike

cp = None
cudf = None
Expand Down Expand Up @@ -351,8 +350,14 @@ def concat_columns(args: list):
"""Dispatch function to concatenate DataFrames with axis=1"""
if len(args) == 1:
return args[0]
elif isinstance(args[0], DictArray):
result = DictArray({})
elif isinstance(args[0], DataFrameLike):
_lib = cudf if HAS_GPU and isinstance(args[0], cudf.DataFrame) else pd
return _lib.concat(
[a.reset_index(drop=True) for a in args],
axis=1,
)
elif isinstance(args[0], DictLike):
result = type(args[0])()
for arg in args:
result.update(arg)
return result
Expand Down
4 changes: 3 additions & 1 deletion merlin/dag/dictarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ class DictArray(Transformable):
A simple dataframe-like wrapper around a dictionary of values
"""

def __init__(self, values: Dict, dtypes: Optional[Dict] = None):
def __init__(self, values: Optional[Dict] = None, dtypes: Optional[Dict] = None):
super().__init__()

values = values or {}

array_values = {}
for key, value in values.items():
array_values[key] = np.array(value) if isinstance(value, list) else value
Expand Down

0 comments on commit c1ddc19

Please sign in to comment.