diff --git a/merlin/core/dispatch.py b/merlin/core/dispatch.py index e746a94f6..4d4262bcb 100644 --- a/merlin/core/dispatch.py +++ b/merlin/core/dispatch.py @@ -374,9 +374,16 @@ def flatten_list_column_values(s): return pd.Series(itertools.chain(*s)) elif cudf and isinstance(s, cudf.Series): return s.list.leaves + elif cp and isinstance(s, cp.ndarray): + return s.flatten() + elif isinstance(s, np.ndarray): + return s.flatten() else: raise ValueError( - "Unsupported series type: " f"{type(s)}" " Expected either a pandas or cudf Series." + "Unsupported series type: " + f"{type(s)} " + "Expected either a pandas or cuDF Series. " + "Or a NumPy or CuPy array" )