-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add and apply DictArray
wrapper class and corresponding Protocol
definitions
#141
Changes from all commits
028140c
7ba04df
027432f
f3f7eb9
5f745f6
9d5dad2
49641f6
ff6d5aa
9ffd963
9fa9cc8
628f12e
2d9320e
96d071a
aa2fc48
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,8 @@ | |
import pyarrow.parquet as pq | ||
|
||
from merlin.core.compat import HAS_GPU | ||
from merlin.core.protocols import DataFrameLike, SeriesLike | ||
from merlin.dag import DictArray | ||
|
||
cp = None | ||
cudf = None | ||
|
@@ -269,13 +271,13 @@ def series_has_nulls(s): | |
return s.has_nulls | ||
|
||
|
||
def list_val_dtype(ser: SeriesType) -> np.dtype: | ||
def list_val_dtype(ser: SeriesLike) -> np.dtype: | ||
""" | ||
Return the dtype of the leaves from a list or nested list | ||
|
||
Parameters | ||
---------- | ||
ser : SeriesType | ||
ser : SeriesLike | ||
A series where the rows contain lists or nested lists | ||
|
||
Returns | ||
|
@@ -347,6 +349,11 @@ def concat_columns(args: list): | |
"""Dispatch function to concatenate DataFrames with axis=1""" | ||
if len(args) == 1: | ||
return args[0] | ||
elif isinstance(args[0], DictArray): | ||
result = DictArray({}) | ||
for arg in args: | ||
result.update(arg) | ||
return result | ||
else: | ||
_lib = cudf if HAS_GPU and isinstance(args[0], cudf.DataFrame) else pd | ||
return _lib.concat( | ||
|
@@ -356,12 +363,12 @@ def concat_columns(args: list): | |
return None | ||
|
||
|
||
def read_parquet_dispatch(df: DataFrameType) -> Callable: | ||
def read_parquet_dispatch(df: DataFrameLike) -> Callable: | ||
"""Dispatch function for reading parquet files""" | ||
return read_dispatch(df=df, fmt="parquet") | ||
|
||
|
||
def read_dispatch(df: DataFrameType = None, cpu=None, collection=False, fmt="parquet") -> Callable: | ||
def read_dispatch(df: DataFrameLike = None, cpu=None, collection=False, fmt="parquet") -> Callable: | ||
"""Return the necessary read_parquet function to generate | ||
data of a specified type. | ||
""" | ||
|
@@ -373,7 +380,7 @@ def read_dispatch(df: DataFrameType = None, cpu=None, collection=False, fmt="par | |
return getattr(_mod, _attr) | ||
|
||
|
||
def parquet_writer_dispatch(df: DataFrameType, path=None, **kwargs): | ||
def parquet_writer_dispatch(df: DataFrameLike, path=None, **kwargs): | ||
"""Return the necessary ParquetWriter class to write | ||
data of a specified type. | ||
|
||
|
@@ -517,9 +524,9 @@ def detect_format(data): | |
"csv": ExtData.CSV, | ||
} | ||
if isinstance(data, list) and data: | ||
file_type = mapping.get(str(data[0]).split(".")[-1], None) | ||
file_type = mapping.get(str(data[0]).rsplit(".", maxsplit=1)[-1], None) | ||
else: | ||
file_type = mapping.get(str(data).split(".")[-1], None) | ||
file_type = mapping.get(str(data).rsplit(".", maxsplit=1)[-1], None) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (These changes are required to appease the linters, but are otherwise unrelated to this PR) |
||
if file_type is None: | ||
raise ValueError("Data format not recognized.") | ||
return file_type | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
# | ||
# Copyright (c) 2022, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
# pylint:disable=too-many-public-methods | ||
from typing import Protocol, runtime_checkable | ||
|
||
|
||
@runtime_checkable | ||
class DictLike(Protocol): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This protocol definition contains the maximal set of methods that are shared by both actual dictionaries and dataframes (which are conceptually dictionaries of columns.) |
||
""" | ||
These methods are present on plain Python dictionaries and also on DataFrames, which | ||
are conceptually a dictionary of columns/series. Both Python dictionaries and DataFrames | ||
therefore implement this Protocol, although neither sub-classes it. That means that | ||
`isinstance(obj, DictLike)` will return `True` at runtime if obj is a dictionary, a DataFrame, | ||
or any other type that implements the following methods. | ||
""" | ||
|
||
def __iter__(self): | ||
return iter({}) | ||
|
||
def __len__(self): | ||
return 0 | ||
|
||
def __getitem__(self, key): | ||
... | ||
|
||
def __setitem__(self, key, value): | ||
... | ||
|
||
def __delitem__(self, key): | ||
... | ||
|
||
def keys(self): | ||
... | ||
|
||
def items(self): | ||
... | ||
|
||
def values(self): | ||
... | ||
|
||
def update(self, other): | ||
... | ||
|
||
def copy(self): | ||
... | ||
|
||
|
||
@runtime_checkable | ||
class SeriesLike(Protocol): | ||
""" | ||
These methods are defined by Pandas and cuDF series, and also by the array-wrapping | ||
`Column` class defined in `merlin.dag`. If we want to provide column-level transformations | ||
on data (e.g. to zero-copy share it across frameworks), the `Column` class would provide | ||
a potential place to do that, and this Protocol would allow us to build abstractions that | ||
make working with arrays and Series interchangeably possible. | ||
""" | ||
|
||
def values(self): | ||
... | ||
|
||
def dtype(self): | ||
... | ||
|
||
def __getitem__(self, index): | ||
... | ||
|
||
def __eq__(self, other): | ||
... | ||
|
||
|
||
@runtime_checkable | ||
class Transformable(DictLike, Protocol): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In addition to the methods dictionary methods that are shared by dataframes, there are a few methods from dataframes that we use so frequently that it's easier to wrap a dictionary in a class and add them to the wrapper class than it would be to refactor the whole code base to do without them. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm wondering if the wrapper class is an implementation detail of the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's possible that we could do that, but it seems like a potential instance of the primitive obsession code smell and would deprive us a place to add methods with additional functionality related to e.g. transferring data across frameworks or treating multiple arrays as a single column. 🤔 |
||
""" | ||
In addition to the dictionary methods that are shared by dataframes, there are a few | ||
methods from dataframes that we use so frequently that it's easier to wrap a dictionary | ||
in a class and add them to the wrapper class than it would be to refactor the whole code | ||
base to do without them. | ||
""" | ||
|
||
@property | ||
def columns(self): | ||
... | ||
|
||
def dtypes(self): | ||
... | ||
|
||
def __getitem__(self, index): | ||
... | ||
|
||
|
||
@runtime_checkable | ||
class DataFrameLike(Transformable, Protocol): | ||
karlhigley marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
This is the maximal set of methods shared by both Pandas dataframes and cuDF dataframes | ||
that aren't already part of the Transformable protocol. In theory, if there were another | ||
dataframe library that implemented the methods in this Protocol (e.g. Polars), we could | ||
use its dataframes in any place where we use the DataFrameLike type, but right now this | ||
protocol is only intended to match Pandas and cuDF dataframes. | ||
""" | ||
|
||
def apply(self): | ||
... | ||
|
||
def describe(self): | ||
... | ||
|
||
def drop(self): | ||
... | ||
|
||
def explode(self): | ||
... | ||
|
||
def groupby(self): | ||
... | ||
|
||
def head(self): | ||
... | ||
|
||
def interpolate(self): | ||
... | ||
|
||
def join(self): | ||
... | ||
|
||
def max(self): | ||
... | ||
|
||
def mean(self): | ||
... | ||
|
||
def median(self): | ||
... | ||
|
||
def pipe(self): | ||
... | ||
|
||
def pivot(self): | ||
... | ||
|
||
def product(self): | ||
... | ||
|
||
def quantile(self): | ||
... | ||
|
||
def rename(self): | ||
... | ||
|
||
def replace(self): | ||
... | ||
|
||
def sample(self): | ||
... | ||
|
||
def shape(self): | ||
... | ||
|
||
def shift(self): | ||
... | ||
|
||
def std(self): | ||
... | ||
|
||
def sum(self): | ||
... | ||
|
||
def tail(self): | ||
... | ||
|
||
def to_dict(self): | ||
... | ||
|
||
def to_numpy(self): | ||
... | ||
|
||
def transpose(self): | ||
... | ||
|
||
def unstack(self): | ||
... | ||
|
||
def var(self): | ||
... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replacing
SeriesType
andDataframeType
(an ongoing pain with MyPy, which does not like dynamically defined types) with the protocols here doesn't functionally change much here, since the actual types don't prevent you from passing the wrong types into the function. The protocols (e.g.SeriesLike
) can be checked at run-time viaisinstance(obj, protocol)
checks. Not sure if there's a way to enforce them automatically at runtime though.