Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add and apply DictArray wrapper class and corresponding Protocol definitions #141

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions merlin/core/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import pyarrow.parquet as pq

from merlin.core.compat import HAS_GPU
from merlin.core.protocols import DataFrameLike, SeriesLike
from merlin.dag import DictArray

cp = None
cudf = None
Expand Down Expand Up @@ -269,13 +271,13 @@ def series_has_nulls(s):
return s.has_nulls


def list_val_dtype(ser: SeriesType) -> np.dtype:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replacing SeriesType and DataFrameType (an ongoing pain with MyPy, which does not like dynamically defined types) with the protocols here doesn't functionally change much, since the actual types don't prevent you from passing the wrong types into the function. The protocols (e.g. SeriesLike) can be checked at run-time via isinstance(obj, protocol) checks. Not sure if there's a way to enforce them automatically at runtime though.

def list_val_dtype(ser: SeriesLike) -> np.dtype:
"""
Return the dtype of the leaves from a list or nested list

Parameters
----------
ser : SeriesType
ser : SeriesLike
A series where the rows contain lists or nested lists

Returns
Expand Down Expand Up @@ -347,6 +349,11 @@ def concat_columns(args: list):
"""Dispatch function to concatenate DataFrames with axis=1"""
if len(args) == 1:
return args[0]
elif isinstance(args[0], DictArray):
result = DictArray({})
for arg in args:
result.update(arg)
return result
else:
_lib = cudf if HAS_GPU and isinstance(args[0], cudf.DataFrame) else pd
return _lib.concat(
Expand All @@ -356,12 +363,12 @@ def concat_columns(args: list):
return None


def read_parquet_dispatch(df: DataFrameType) -> Callable:
def read_parquet_dispatch(df: DataFrameLike) -> Callable:
    """Dispatch to the appropriate parquet-reading function for ``df``'s framework."""
    return read_dispatch(df=df, fmt="parquet")


def read_dispatch(df: DataFrameType = None, cpu=None, collection=False, fmt="parquet") -> Callable:
def read_dispatch(df: DataFrameLike = None, cpu=None, collection=False, fmt="parquet") -> Callable:
"""Return the necessary read_parquet function to generate
data of a specified type.
"""
Expand All @@ -373,7 +380,7 @@ def read_dispatch(df: DataFrameType = None, cpu=None, collection=False, fmt="par
return getattr(_mod, _attr)


def parquet_writer_dispatch(df: DataFrameType, path=None, **kwargs):
def parquet_writer_dispatch(df: DataFrameLike, path=None, **kwargs):
"""Return the necessary ParquetWriter class to write
data of a specified type.

Expand Down Expand Up @@ -517,9 +524,9 @@ def detect_format(data):
"csv": ExtData.CSV,
}
if isinstance(data, list) and data:
file_type = mapping.get(str(data[0]).split(".")[-1], None)
file_type = mapping.get(str(data[0]).rsplit(".", maxsplit=1)[-1], None)
else:
file_type = mapping.get(str(data).split(".")[-1], None)
file_type = mapping.get(str(data).rsplit(".", maxsplit=1)[-1], None)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(These changes are required to appease the linters, but are otherwise unrelated to this PR)

if file_type is None:
raise ValueError("Data format not recognized.")
return file_type
Expand Down
197 changes: 197 additions & 0 deletions merlin/core/protocols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pylint:disable=too-many-public-methods
from typing import Protocol, runtime_checkable


@runtime_checkable
class DictLike(Protocol):
    """
    Structural type for "a dictionary of columns".

    Plain Python dictionaries and DataFrames (which are conceptually
    dictionaries of columns/series) both provide every method listed here,
    so both match this Protocol structurally even though neither subclasses
    it.  Because the Protocol is ``runtime_checkable``,
    ``isinstance(obj, DictLike)`` returns ``True`` for any object that
    provides these attributes — signatures are not checked at runtime.
    """

    # Container dunders shared by dicts and dataframes.
    def __iter__(self):
        return iter({})

    def __len__(self):
        return 0

    def __getitem__(self, key): ...

    def __setitem__(self, key, value): ...

    def __delitem__(self, key): ...

    # Mapping-style accessors.
    def keys(self): ...

    def items(self): ...

    def values(self): ...

    # Bulk operations.
    def update(self, other): ...

    def copy(self): ...


@runtime_checkable
class SeriesLike(Protocol):
"""
These methods are defined by Pandas and cuDF series, and also by the array-wrapping
`Column` class defined in `merlin.dag`. If we want to provide column-level transformations
on data (e.g. to zero-copy share it across frameworks), the `Column` class would provide
a potential place to do that, and this Protocol would allow us to build abstractions that
make working with arrays and Series interchangeably possible.
"""

def values(self):
...

def dtype(self):
...

def __getitem__(self, index):
...

def __eq__(self, other):
...


@runtime_checkable
class Transformable(DictLike, Protocol):
    """
    Extends ``DictLike`` with the handful of dataframe methods we use so
    frequently that wrapping a dictionary in a class that adds them is
    easier than refactoring the whole code base to do without them.
    """

    @property
    def columns(self): ...

    def dtypes(self): ...

    def __getitem__(self, index): ...


@runtime_checkable
class DataFrameLike(Transformable, Protocol):
    """
    The maximal set of methods shared by Pandas and cuDF dataframes beyond
    what ``Transformable`` already requires.  In principle any dataframe
    library implementing these methods (e.g. Polars) could be used wherever
    ``DataFrameLike`` is expected, but for now this protocol is only
    intended to match Pandas and cuDF dataframes.
    """

    # Structure / metadata
    def shape(self): ...
    def head(self): ...
    def tail(self): ...

    # Reshaping
    def drop(self): ...
    def explode(self): ...
    def pivot(self): ...
    def rename(self): ...
    def transpose(self): ...
    def unstack(self): ...

    # Combining / functional application
    def apply(self): ...
    def groupby(self): ...
    def join(self): ...
    def pipe(self): ...

    # Statistics / aggregation
    def describe(self): ...
    def max(self): ...
    def mean(self): ...
    def median(self): ...
    def product(self): ...
    def quantile(self): ...
    def std(self): ...
    def sum(self): ...
    def var(self): ...

    # Value manipulation
    def interpolate(self): ...
    def replace(self): ...
    def sample(self): ...
    def shift(self): ...

    # Conversion
    def to_dict(self): ...
    def to_numpy(self): ...
1 change: 1 addition & 0 deletions merlin/dag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

# flake8: noqa
from merlin.dag.base_operator import BaseOperator, Supports
from merlin.dag.dictarray import DictArray
from merlin.dag.graph import Graph
from merlin.dag.node import Node, iter_nodes, postorder_iter_nodes, preorder_iter_nodes
from merlin.dag.selector import ColumnSelector
33 changes: 33 additions & 0 deletions merlin/dag/base_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any, List, Union

import merlin.dag
from merlin.core.protocols import Transformable
from merlin.dag.selector import ColumnSelector
from merlin.schema import ColumnSchema, Schema

Expand Down Expand Up @@ -131,7 +132,39 @@ def compute_output_schema(

return output_schema

def transform(
    self, col_selector: ColumnSelector, transformable: Transformable
) -> Transformable:
    """Apply this operator to the set of selected input columns.

    The base implementation is the identity transform: subclasses override
    this to do actual work.

    Parameters
    ----------
    col_selector : ColumnSelector
        The columns this operator should be applied to
    transformable : Transformable
        A pandas or cudf dataframe this operator will work on

    Returns
    -------
    Transformable
        The transformed dataframe or dictarray produced by this operator
    """
    return transformable

def column_mapping(self, col_selector):
"""
Compute which output columns depend on which input columns

Parameters
----------
col_selector : ColumnSelector
A selector containing a list of column names

Returns
-------
Dict[str, List[str]]
Mapping from output column names to list of the input columns they rely on
"""
column_mapping = {}
for col_name in col_selector.names:
column_mapping[col_name] = [col_name]
Expand Down
Loading