diff --git a/merlin/core/dispatch.py b/merlin/core/dispatch.py index 4d4262bcb..941a1320a 100644 --- a/merlin/core/dispatch.py +++ b/merlin/core/dispatch.py @@ -343,10 +343,8 @@ def is_list_dtype(ser): return pd.api.types.is_list_like(ser.values[0]) elif cudf and isinstance(ser, (cudf.Series, cudf.ListDtype)): return cudf_is_list_dtype(ser) - elif cudf and isinstance(ser, cp.ndarray): - return pd.api.types.is_list_like(ser[0]) - elif isinstance(ser, np.ndarray): - return pd.api.types.is_list_like(ser[0]) + elif isinstance(ser, np.ndarray) or (cp and isinstance(ser, cp.ndarray)): + return len(ser.shape) > 1 return pd.api.types.is_list_like(ser)