diff --git a/merlin/systems/dag/ensemble.py b/merlin/systems/dag/ensemble.py
index b3adaccce..8ce34e828 100644
--- a/merlin/systems/dag/ensemble.py
+++ b/merlin/systems/dag/ensemble.py
@@ -23,6 +23,7 @@
 from merlin.core.protocols import Transformable
 from merlin.dag import Graph
+from merlin.systems.dag.runtimes.base_runtime import Runtime
 from merlin.systems.dag.runtimes.triton import TritonExecutorRuntime
@@ -74,7 +75,7 @@ def transform(self, transformable: Transformable, runtime=None):
         Transformable
             transformed data
         """
-        runtime = runtime or TritonExecutorRuntime()
+        runtime = runtime or Runtime()
         return runtime.transform(self.graph, transformable)

     def save(self, path):
diff --git a/merlin/systems/dag/runtimes/__init__.py b/merlin/systems/dag/runtimes/__init__.py
index a6e44cb75..880882724 100644
--- a/merlin/systems/dag/runtimes/__init__.py
+++ b/merlin/systems/dag/runtimes/__init__.py
@@ -13,6 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-# flake8: noqa
+# flake8: noqa

 from .base_runtime import Runtime
diff --git a/merlin/systems/dag/runtimes/base_runtime.py b/merlin/systems/dag/runtimes/base_runtime.py
index c68eeaadb..f393350fc 100644
--- a/merlin/systems/dag/runtimes/base_runtime.py
+++ b/merlin/systems/dag/runtimes/base_runtime.py
@@ -14,12 +14,14 @@
 # limitations under the License.
 #
 from merlin.core.protocols import Transformable
-from merlin.dag import Graph
+from merlin.dag import Graph, postorder_iter_nodes
 from merlin.dag.executors import LocalExecutor
+from merlin.systems.dag.runtimes.op_table import OpTable


 class Runtime:
-    """A Systems Graph Runtime.
+    """
+    A runtime for Merlin DAGs that supports using custom implementations of existing operators.

     This class can be used as a base class for custom runtimes.
     """
@@ -33,9 +35,32 @@ def __init__(self, executor=None):
             The Graph Executor to use for the transform, by default None
         """
         self.executor = executor or LocalExecutor()
-        self.op_table = {}
+        self.op_table = OpTable()
+
+    def convert(self, graph: Graph):
+        """
+        Replace the operators in the supplied graph with ops from this runtime's op table
+
+        Parameters
+        ----------
+        graph : Graph
+            Graph of nodes containing operator chains for data manipulation.
-    def transform(self, graph: Graph, transformable: Transformable):
+
+        Returns
+        -------
+        Graph
+            Copy of the graph with operators converted to this runtime's versions
+        """
+        if not self.op_table.empty:
+            nodes = list(postorder_iter_nodes(graph.output_node))
+
+            for node in nodes:
+                if self.op_table.has_impl(node.op):
+                    node.op = self.op_table.replace(node.op)
+
+        return graph
+
+    def transform(self, graph: Graph, transformable: Transformable, convert=True):
         """Run the graph with the input data.

         Parameters
@@ -44,12 +69,17 @@
         ----------
         graph : Graph
             Graph of nodes containing operator chains for data manipulation.
         transformable : Transformable
            Input data to transform in graph.
+        convert : bool
+            If True, converts the operators in the graph to this runtime's versions

         Returns
         -------
         Transformable
             Input data after it has been transformed via graph.
         """
+        if convert:
+            graph = self.convert(graph)
+
         return self.executor.transform(transformable, [graph.output_node])

     def export(self):
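The `convert`/`transform` split above is the core of this change: a `Runtime` now owns an `OpTable` and rewrites the graph before executing it. A minimal sketch of that mechanism, using toy stand-in classes rather than real Merlin operators (assumes this branch of merlin-systems is installed):

```python
from merlin.systems.dag.runtimes import Runtime


class ToyOp:
    """Stand-in for a Merlin operator (illustrative only)."""


class FasterToyOp:
    """Stand-in replacement; op-table impls are constructed from the op they replace."""

    def __init__(self, op):
        self.original = op


runtime = Runtime()
runtime.op_table.register(ToyOp, FasterToyOp)

# Runtime.convert() walks the graph in postorder and performs exactly this
# swap for every node whose op has a registered implementation:
op = ToyOp()
if runtime.op_table.has_impl(op):
    op = runtime.op_table.replace(op)

print(type(op).__name__)  # FasterToyOp
```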
""" + if convert: + graph = self.convert(graph) + return self.executor.transform(transformable, [graph.output_node]) def export(self): diff --git a/merlin/systems/dag/runtimes/nvtabular/__init__.py b/merlin/systems/dag/runtimes/nvtabular/__init__.py new file mode 100644 index 000000000..68b2015fe --- /dev/null +++ b/merlin/systems/dag/runtimes/nvtabular/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/merlin/systems/dag/runtimes/nvtabular/executor.py b/merlin/systems/dag/runtimes/nvtabular/executor.py new file mode 100644 index 000000000..1a4f77d94 --- /dev/null +++ b/merlin/systems/dag/runtimes/nvtabular/executor.py @@ -0,0 +1,305 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import functools +import itertools +import logging + +from merlin.core.compat import cudf +from merlin.core.compat import cupy as cp +from merlin.core.compat import numpy as np +from merlin.core.compat import pandas +from merlin.core.dispatch import build_cudf_list_column, concat_columns, is_list_dtype +from merlin.dag import Graph, Node, Supports +from merlin.dag.executors import LocalExecutor +from merlin.table import CupyColumn, NumpyColumn, TensorTable + +LOG = logging.getLogger("merlin-systems") + + +class NVTabularServingExecutor(LocalExecutor): + """ + An executor for running Merlin operator DAGs locally + """ + + def __init__(self, device: str): + self.device = device + + def transform( + self, + transformable, + graph, + output_dtypes=None, + additional_columns=None, + capture_dtypes=False, + strict=False, + output_format=Supports.CPU_DICT_ARRAY, + ): + """ + Transforms a single dataframe (possibly a partition of a Dask Dataframe) + by applying the operators from a collection of Nodes + """ + nodes = [] + if isinstance(graph, Graph): + nodes.append(graph.output_node) + elif isinstance(graph, Node): + nodes.append(graph) + elif isinstance(graph, list): + nodes = graph + else: + raise TypeError( + f"LocalExecutor detected unsupported type of input for graph: {type(graph)}." 
+
+        output_data = None
+
+        for node in nodes:
+            transformed_data = self._execute_node(node, transformable)
+            output_data = self._combine_node_outputs(node, transformed_data, output_data)
+
+        if additional_columns:
+            output_data = concat_columns(
+                [output_data, transformable[_get_unique(additional_columns)]]
+            )
+
+        format_ = _data_format(output_data)
+        if format_ != output_format:
+            output_data = _convert_format(output_data, output_format)
+
+        return output_data
+
+    def _execute_node(self, workflow_node, input_tensors, capture_dtypes=False, strict=False):
+        upstream_outputs = self._run_upstream_transforms(workflow_node, input_tensors)
+        upstream_outputs = self._merge_addl_root_columns(
+            workflow_node, input_tensors, upstream_outputs
+        )
+        tensors = self._standardize_formats(workflow_node, upstream_outputs)
+
+        transform_input = _concat_tensors(tensors)
+        # TODO: In order to replace the line above with the line below, we first have to replace
+        # dictionaries with TensorTables
+        # transform_input = self._merge_upstream_columns(tensors, merge_fn=_concat_tensors)
+        transform_output = self._run_node_transform(workflow_node, transform_input)
+
+        return transform_output
+
+    def _merge_addl_root_columns(self, workflow_node, input_tensors, upstream_outputs):
+        if workflow_node.selector:
+            selector_columns = workflow_node.selector.names
+            to_remove = []
+            for upstream_tensors in upstream_outputs or []:
+                for col in selector_columns:
+                    if col in upstream_tensors:
+                        to_remove.append(col)
+            for col in set(to_remove):
+                selector_columns.remove(col)
+
+            if selector_columns:
+                selected_tensors = {c: input_tensors[c] for c in selector_columns}
+                upstream_outputs.append(selected_tensors)
+
+        return upstream_outputs
+
+    def _standardize_formats(self, workflow_node, node_input_data):
+        # Get the supported formats
+        op = workflow_node.op
+        if op and hasattr(op, "inference_initialize"):
+            supported_formats = _maybe_mask_cpu_only(op.supports, self.device)
+        else:
+            supported_formats = Supports.CPU_DICT_ARRAY
+
+        # Convert the first input into a supported format
+        tensors = _convert_format(node_input_data[0], supported_formats)
+        target_format = _data_format(tensors)
+
+        # Convert the whole list into the same format
+        formatted_tensors = []
+        for upstream_tensors in node_input_data:
+            upstream_tensors = _convert_format(upstream_tensors, target_format)
+            formatted_tensors.append(upstream_tensors)
+
+        return formatted_tensors
+
+
+def _concat_tensors(tensors):
+    format_ = _data_format(tensors[0])
+
+    if format_ & (Supports.GPU_DATAFRAME | Supports.CPU_DATAFRAME):
+        return concat_columns(tensors)
+    else:
+        output = tensors[0]
+        for tensor in tensors[1:]:
+            output.update(tensor)
+        return output
+
+
+def _maybe_mask_cpu_only(supported, device):
+    # if we're running on the CPU only, mask off support for GPU data formats
+    if device == "CPU":
+        supported = functools.reduce(
+            lambda a, b: a | b,
+            (v for v in list(Supports) if v & supported and "CPU" in str(v)),
+        )
+
+    return supported
+
+
+def _get_unique(cols):
+    # Need to preserve order in unique-column list
+    return list({x: x for x in cols}.keys())
+
+
+def _data_format(transformable):
+    data = TensorTable(transformable) if isinstance(transformable, dict) else transformable
+
+    if cudf and isinstance(data, cudf.DataFrame):
+        return Supports.GPU_DATAFRAME
+    elif pandas and isinstance(data, pandas.DataFrame):
+        return Supports.CPU_DATAFRAME
+    elif data.column_type is CupyColumn:
+        return Supports.GPU_DICT_ARRAY
+    elif data.column_type is NumpyColumn:
+        return Supports.CPU_DICT_ARRAY
+    else:
+        if isinstance(data, TensorTable):
+            raise TypeError(f"Unknown type: {data.column_type}")
+        else:
+            raise TypeError(f"Unknown type: {type(data)}")
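The format negotiation in this executor (`_standardize_formats` above, `_convert_format` below) relies on `Supports` being a flag enum, so a single bitwise AND against a mask of acceptable formats decides whether any conversion is needed at all. A small illustration (assumes merlin-core is installed; the enum members are the ones used throughout this diff):

```python
from merlin.dag import Supports

# An operator that accepts either CPU representation:
accepted = Supports.CPU_DICT_ARRAY | Supports.CPU_DATAFRAME

print(bool(Supports.CPU_DICT_ARRAY & accepted))  # True  -> no conversion needed
print(bool(Supports.GPU_DATAFRAME & accepted))   # False -> convert first
```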
+
+
+def _convert_format(tensors, target_format):
+    """
+    Converts data to one of the formats specified in 'target_format'
+
+    This allows us to convert data to/from dataframe representations for operators that
+    only support certain representations
+    """
+    format_ = _data_format(tensors)
+
+    # this is all much more difficult because of multihot columns, which don't have
+    # great representations in dicts of cpu/gpu arrays. we're representing multihots
+    # as tuples of (values, offsets) tensors in this case - but have to do work at
+    # each step in terms of converting.
+    if format_ & target_format:
+        return tensors
+
+    elif target_format & Supports.GPU_DICT_ARRAY:
+        if format_ == Supports.CPU_DICT_ARRAY:
+            return _convert_array(tensors, cp.array)
+        elif format_ == Supports.CPU_DATAFRAME:
+            return _pandas_to_array(tensors, False)
+        elif format_ == Supports.GPU_DATAFRAME:
+            return _cudf_to_array(tensors, False)
+
+    elif target_format & Supports.CPU_DICT_ARRAY:
+        if format_ == Supports.GPU_DICT_ARRAY:
+            return _convert_array(tensors, cp.asnumpy)
+        elif format_ == Supports.CPU_DATAFRAME:
+            return _pandas_to_array(tensors, True)
+        elif format_ == Supports.GPU_DATAFRAME:
+            return _cudf_to_array(tensors, True)
+
+    elif cudf and target_format & Supports.GPU_DATAFRAME:
+        if format_ == Supports.CPU_DATAFRAME:
+            return cudf.DataFrame(tensors)
+        return _array_to_cudf(tensors)
+
+    elif target_format & Supports.CPU_DATAFRAME:
+        if format_ == Supports.GPU_DATAFRAME:
+            return tensors.to_pandas()
+        elif format_ == Supports.CPU_DICT_ARRAY:
+            return _array_to_pandas(tensors)
+        elif format_ == Supports.GPU_DICT_ARRAY:
+            return _array_to_pandas(_convert_array(tensors, cp.asnumpy))
+
+    raise ValueError("unsupported target for converting tensors", target_format)
+
+
+def _convert_array(tensors, converter):
+    output = {}
+    for name, tensor in tensors.items():
+        if isinstance(tensor, tuple):
+            output[name] = tuple(converter(t) for t in tensor)
+        else:
+            output[name] = converter(tensor)
+    return output
+
+
+def _array_to_pandas(tensors):
+    output = pandas.DataFrame()
+    for name, tensor in tensors.items():
+        if isinstance(tensor, tuple):
+            values, offsets = tensor
+            output[name] = [values[offsets[i] : offsets[i + 1]] for i in range(len(offsets) - 1)]
+        else:
+            output[name] = tensor
+    return output
+
+
+def _array_to_cudf(tensors):
+    output = cudf.DataFrame()
+    for name, tensor in tensors.items():
+        if isinstance(tensor, tuple):
+            output[name] = build_cudf_list_column(tensor[0], tensor[1].astype("int32"))
+        else:
+            output[name] = tensor
+    return output
+
+
+def _pandas_to_array(df, cpu=True):
+    array_type = np.array if cpu else cp.array
+
+    output = {}
+    for name in df.columns:
+        col = df[name]
+        if pandas.api.types.is_list_like(col.values[0]):
+            values = array_type(list(itertools.chain(*col)))
+            row_lengths = col.map(len)
+            if all(row_lengths == row_lengths[0]):
+                output[name] = values.reshape((-1, row_lengths[0]))
+            else:
+                offsets = pandas.Series([0]).append(row_lengths.cumsum()).values
+                if not cpu:
+                    offsets = cp.array(offsets)
+                output[name] = (values, offsets)
+        else:
+            values = col.values
+            if not cpu:
+                values = cp.array(values)
+            output[name] = values
+
+    return output
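The converters above and `_cudf_to_array` below share one convention for ragged "multihot" list columns: a `(values, offsets)` tuple where row `i` spans `values[offsets[i]:offsets[i + 1]]`. A standalone numpy illustration of that encoding and the slicing `_array_to_pandas` performs:

```python
import numpy as np

values = np.array([10, 11, 12, 13, 14, 15])
offsets = np.array([0, 2, 3, 6])  # three rows of lengths 2, 1, 3

rows = [values[offsets[i] : offsets[i + 1]] for i in range(len(offsets) - 1)]
print(rows)  # [array([10, 11]), array([12]), array([13, 14, 15])]

# When every row has the same length, the executor reshapes to a dense 2D
# array instead of keeping the (values, offsets) pair.
```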
+
+
+def _cudf_to_array(df, cpu=True):
+    output = {}
+    for name in df.columns:
+        col = df[name]
+        if is_list_dtype(col.dtype):
+            values = col.list.leaves.values_host if cpu else col.list.leaves.values
+            offsets = col._column.offsets.values_host if cpu else col._column.offsets.values
+
+            row_lengths = offsets[1:] - offsets[:-1]
+            if all(row_lengths == row_lengths[0]):
+                output[name] = values.reshape((-1, row_lengths[0]))
+            else:
+                output[name] = (values, offsets)
+        else:
+            output[name] = col.values_host if cpu else col.values
+
+    return output
diff --git a/merlin/systems/dag/runtimes/nvtabular/runtime.py b/merlin/systems/dag/runtimes/nvtabular/runtime.py
new file mode 100644
index 000000000..a379bc106
--- /dev/null
+++ b/merlin/systems/dag/runtimes/nvtabular/runtime.py
@@ -0,0 +1,40 @@
+#
+# Copyright (c) 2023, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+
+import nvtabular_cpp
+
+from merlin.systems.dag.runtimes import Runtime
+from merlin.systems.dag.runtimes.nvtabular.executor import NVTabularServingExecutor
+from merlin.systems.dag.runtimes.op_table import OpTable
+from nvtabular.ops import Categorify, FillMissing
+
+LOG = logging.getLogger("merlin-systems")
+
+
+NVTABULAR_OP_TABLE = OpTable()
+NVTABULAR_OP_TABLE.register(
+    Categorify, nvtabular_cpp.inference.CategorifyTransform, lambda op: op.encode_type != "combo"
+)
+NVTABULAR_OP_TABLE.register(
+    FillMissing, nvtabular_cpp.inference.FillTransform, lambda op: not op.add_binary_cols
+)
+
+
+class NVTabularServingRuntime(Runtime):
+    """Runtime for serving NVTabular workflows, substituting C++ op implementations where available."""
+
+    def __init__(self, device: str):
+        super().__init__(executor=NVTabularServingExecutor(device))
+        self.op_table = NVTABULAR_OP_TABLE
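For context, this runtime is wired up by `workflow_model.py` later in this diff. A hedged sketch of the same flow outside Triton (assumes nvtabular with its C++ extension is installed and a workflow has been saved to disk; the path and input table are placeholders, not values from this PR):

```python
import nvtabular
from merlin.systems.dag.runtimes.nvtabular.runtime import NVTabularServingRuntime

runtime = NVTabularServingRuntime("GPU")  # or "CPU"

workflow = nvtabular.Workflow.load("/path/to/saved/workflow")  # placeholder path
workflow.graph = runtime.convert(workflow.graph)  # swap in C++ ops where eligible

# input_table would be a TensorTable or a dict of arrays keyed by column name:
# transformed = runtime.transform(workflow.graph, input_table)
```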
diff --git a/merlin/systems/dag/runtimes/op_table.py b/merlin/systems/dag/runtimes/op_table.py
new file mode 100644
index 000000000..2ce164242
--- /dev/null
+++ b/merlin/systems/dag/runtimes/op_table.py
@@ -0,0 +1,83 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+class OpTable:
+    """
+    A table of alternate implementations of Merlin DAG ops to be used in a particular Runtime
+    """
+
+    def __init__(self):
+        self.ops = {}
+        self.conditions = {}
+
+    @property
+    def empty(self):
+        return len(self.ops) == 0
+
+    def register(self, op, op_impl, condition=None):
+        """
+        Register an alternate implementation for an operator
+
+        Parameters
+        ----------
+        op : Operator
+            The operator type to replace
+        op_impl : Operator
+            The alternate implementation to replace it with
+        condition : Callable, optional
+            A boolean predicate that must be satisfied for the replacement to happen,
+            by default None
+        """
+        self.ops[op] = op_impl
+        if condition:
+            self.conditions[op] = condition
+
+    def has_impl(self, op):
+        """
+        Check if this OpTable has an alternate implementation for a particular operator
+
+        Parameters
+        ----------
+        op : Operator
+            The operator to check for alternate implementations of
+
+        Returns
+        -------
+        bool
+            True if there is a registered implementation and either the registered
+            condition is satisfied or there is no condition for the op's type
+        """
+        op_type = type(op)
+        return op_type in self.ops and (
+            op_type not in self.conditions or self.conditions[op_type](op)
+        )
+
+    def replace(self, op):
+        """
+        Create an operator that replaces `op` in a Merlin DAG
+
+        Parameters
+        ----------
+        op : Operator
+            The operator to build a replacement for
+
+        Returns
+        -------
+        Operator
+            Replacement operator for the original op
+        """
+        op_type = type(op)
+        return self.ops[op_type](op)
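The `condition` hook is what lets `NVTABULAR_OP_TABLE` above decline a swap for configurations the C++ ops don't handle (e.g. `Categorify(encode_type="combo")`). A standalone demonstration with toy classes (not the real operators):

```python
from merlin.systems.dag.runtimes.op_table import OpTable


class ToyCategorify:
    def __init__(self, encode_type="joint"):
        self.encode_type = encode_type


class ToyCppCategorify:
    def __init__(self, op):
        self.encode_type = op.encode_type


table = OpTable()
table.register(ToyCategorify, ToyCppCategorify, lambda op: op.encode_type != "combo")

print(table.has_impl(ToyCategorify("joint")))  # True  -> will be replaced
print(table.has_impl(ToyCategorify("combo")))  # False -> original op is kept
```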
diff --git a/merlin/systems/dag/runtimes/triton/runtime.py b/merlin/systems/dag/runtimes/triton/runtime.py
index bbb2de1b5..c20a24c1f 100644
--- a/merlin/systems/dag/runtimes/triton/runtime.py
+++ b/merlin/systems/dag/runtimes/triton/runtime.py
@@ -33,6 +34,7 @@
 )
 from merlin.systems.dag.ops.workflow import TransformWorkflow
 from merlin.systems.dag.runtimes import Runtime
+from merlin.systems.dag.runtimes.op_table import OpTable
 from merlin.systems.dag.runtimes.triton.ops.operator import TritonOperator, add_model_param
 from merlin.systems.dag.runtimes.triton.ops.workflow import TransformWorkflowTriton
@@ -53,26 +54,26 @@
 ...
-TRITON_OP_TABLE = {}
-TRITON_OP_TABLE[TransformWorkflow] = TransformWorkflowTriton
+TRITON_OP_TABLE = OpTable()
+TRITON_OP_TABLE.register(TransformWorkflow, TransformWorkflowTriton)

 if cuml_ensemble or lightgbm or sklearn_ensemble or treelite_sklearn or xgboost:
     from merlin.systems.dag.ops.fil import PredictForest
     from merlin.systems.dag.runtimes.triton.ops.fil import PredictForestTriton

-    TRITON_OP_TABLE[PredictForest] = PredictForestTriton
+    TRITON_OP_TABLE.register(PredictForest, PredictForestTriton)

 if tensorflow:
     from merlin.systems.dag.ops.tensorflow import PredictTensorflow
     from merlin.systems.dag.runtimes.triton.ops.tensorflow import PredictTensorflowTriton

-    TRITON_OP_TABLE[PredictTensorflow] = PredictTensorflowTriton
+    TRITON_OP_TABLE.register(PredictTensorflow, PredictTensorflowTriton)

 if torch:
     from merlin.systems.dag.ops.pytorch import PredictPyTorch
     from merlin.systems.dag.runtimes.triton.ops.pytorch import PredictPyTorchTriton

-    TRITON_OP_TABLE[PredictPyTorch] = PredictPyTorchTriton
+    TRITON_OP_TABLE.register(PredictPyTorch, PredictPyTorchTriton)


 class TritonExecutorRuntime(Runtime):
@@ -112,13 +113,9 @@ def export(
             Tuple of ensemble config and list of non-python backend model configs
         """
         triton_model_name = name or "executor_model"
+        ensemble.graph = self.convert(ensemble.graph)
         nodes = list(postorder_iter_nodes(ensemble.graph.output_node))
-
-        for node in nodes:
-            if type(node.op) in self.op_table:
-                node.op = self.op_table[type(node.op)](node.op)
-
         node_id_table, _ = _create_node_table(nodes)

         # Path where extra files can be optionally saved by operators
diff --git a/merlin/systems/triton/conversions.py b/merlin/systems/triton/conversions.py
index e62bf7b53..6e7660f63 100644
--- a/merlin/systems/triton/conversions.py
+++ b/merlin/systems/triton/conversions.py
@@ -24,15 +24,8 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-import itertools
-
-import numpy as np
-import pandas as pd
-
-from merlin.core.compat import cudf
 from merlin.core.compat import cupy as cp
-from merlin.core.dispatch import build_cudf_list_column, is_list_dtype
-from merlin.dag import Supports
+from merlin.core.compat import numpy as np
 from merlin.systems.dag.ops.compat import pb_utils
 from merlin.table import TensorTable

@@ -209,123 +202,3 @@ def _to_array_lib(triton_tensor):
         "Can't convert Triton GPU tensors to CuPy tensors without CuPy available. "
         "Is it installed?"
     )
-
-
-def convert_format(tensors, kind, target_kind):
-    """Converts data from format 'kind' to one of the formats specified in 'target_kind'
-    This allows us to convert data to/from dataframe representations for operators that
-    only support certain reprentations
-    """
-
-    # this is all much more difficult because of multihot columns, which don't have
-    # great representations in dicts of cpu/gpu arrays. we're representing multihots
-    # as tuples of (values, offsets) tensors in this case - but have to do work at
-    # each step in terms of converting.
- if kind & target_kind: - return tensors, kind - - elif target_kind & Supports.GPU_DICT_ARRAY: - if kind == Supports.CPU_DICT_ARRAY: - return _convert_array(tensors, cp.array), Supports.GPU_DICT_ARRAY - elif kind == Supports.CPU_DATAFRAME: - return _pandas_to_array(tensors, False), Supports.GPU_DICT_ARRAY - elif kind == Supports.GPU_DATAFRAME: - return _cudf_to_array(tensors, False), Supports.GPU_DICT_ARRAY - - elif target_kind & Supports.CPU_DICT_ARRAY: - if kind == Supports.GPU_DICT_ARRAY: - return _convert_array(tensors, cp.asnumpy), Supports.CPU_DICT_ARRAY - elif kind == Supports.CPU_DATAFRAME: - return _pandas_to_array(tensors, True), Supports.CPU_DICT_ARRAY - elif kind == Supports.GPU_DATAFRAME: - return _cudf_to_array(tensors, True), Supports.CPU_DICT_ARRAY - - elif cudf and target_kind & Supports.GPU_DATAFRAME: - if kind == Supports.CPU_DATAFRAME: - return cudf.DataFrame(tensors), Supports.GPU_DATAFRAME - return _array_to_cudf(tensors), Supports.GPU_DATAFRAME - - elif target_kind & Supports.CPU_DATAFRAME: - if kind == Supports.GPU_DATAFRAME: - return tensors.to_pandas(), Supports.CPU_DATAFRAME - elif kind == Supports.CPU_DICT_ARRAY: - return _array_to_pandas(tensors), Supports.CPU_DATAFRAME - elif kind == Supports.GPU_DICT_ARRAY: - return _array_to_pandas(_convert_array(tensors, cp.asnumpy)), Supports.CPU_DATAFRAME - - raise ValueError("unsupported target for converting tensors", target_kind) - - -def _convert_array(tensors, converter): - output = {} - for name, tensor in tensors.items(): - if isinstance(tensor, tuple): - output[name] = tuple(converter(t) for t in tensor) - else: - output[name] = converter(tensor) - return output - - -def _array_to_pandas(tensors): - output = pd.DataFrame() - for name, tensor in tensors.items(): - if isinstance(tensor, tuple): - values, offsets = tensor - output[name] = [values[offsets[i] : offsets[i + 1]] for i in range(len(offsets) - 1)] - else: - output[name] = tensor - return output - - -def _array_to_cudf(tensors): - output = cudf.DataFrame() - for name, tensor in tensors.items(): - if isinstance(tensor, tuple): - output[name] = build_cudf_list_column(tensor[0], tensor[1].astype("int32")) - else: - output[name] = tensor - return output - - -def _pandas_to_array(df, cpu=True): - array_type = np.array if cpu else cp.array - - output = {} - for name in df.columns: - col = df[name] - if pd.api.types.is_list_like(col.values[0]): - values = array_type(list(itertools.chain(*col))) - row_lengths = col.map(len) - if all(row_lengths == row_lengths[0]): - output[name] = values.reshape((-1, row_lengths[0])) - else: - offsets = pd.Series([0]).append(row_lengths.cumsum()).values - if not cpu: - offsets = cp.array(offsets) - output[name] = (values, offsets) - else: - values = col.values - if not cpu: - values = cp.array(values) - output[name] = values - - return output - - -def _cudf_to_array(df, cpu=True): - output = {} - for name in df.columns: - col = df[name] - if is_list_dtype(col.dtype): - values = col.list.leaves.values_host if cpu else col.list.leaves.values - offsets = col._column.offsets.values_host if cpu else col._column.offsets.values - - row_lengths = offsets[1:] - offsets[:-1] - if all(row_lengths == row_lengths[0]): - output[name] = values.reshape((-1, row_lengths[0])) - else: - output[name] = (values, offsets) - else: - output[name] = col.values_host if cpu else col.values - - return output diff --git a/merlin/systems/triton/models/workflow_model.py b/merlin/systems/triton/models/workflow_model.py index 2af192e23..a6e3af87e 100644 --- 
+++ b/merlin/systems/triton/models/workflow_model.py
@@ -25,15 +25,21 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 import json
+import logging
 import pathlib

-import triton_python_backend_utils as pb_utils
-
 import nvtabular
-from merlin.core.dispatch import is_list_dtype
-from merlin.systems.triton import _convert_tensor
+from merlin.dag import ColumnSelector
+from merlin.schema import Tags
+from merlin.systems.dag.runtimes.nvtabular.runtime import NVTabularServingRuntime
+from merlin.systems.triton.conversions import (
+    tensor_table_to_triton_response,
+    triton_request_to_tensor_table,
+)
 from merlin.systems.triton.utils import triton_error_handling, triton_multi_request
-from merlin.systems.workflow.base import WorkflowRunner
+from merlin.table import TensorTable
+
+LOG = logging.getLogger("merlin-systems")


 class TritonPythonModel:
@@ -64,33 +70,31 @@ def initialize(self, args):
             repository_path = repository_path.parent.parent

         workflow_path = repository_path / str(args["model_version"]) / "workflow"
+        model_device = args["model_instance_kind"]
+        self.runtime = NVTabularServingRuntime(model_device)

         # Workflow instantiation
-        self.workflow = nvtabular.Workflow.load(str(workflow_path))
+        workflow = nvtabular.Workflow.load(str(workflow_path))
+        workflow.graph = self.runtime.convert(workflow.graph)
+        self.workflow = workflow

         # Config loading and parsing
-        self.model_config = json.loads(args["model_config"])
+        model_config = json.loads(args["model_config"])

-        # Dtype parsing
-        input_dtypes = self.workflow.input_dtypes.items()
-        self.input_dtypes, self.input_multihots = _parse_input_dtypes(input_dtypes)
+        mc_cats, mc_conts = _parse_mc_features(model_config)
+        schema_cats, schema_conts = _parse_schema_features(self.workflow.output_schema)

-        self.output_dtypes = {}
-        for col_name, col_schema in self.workflow.output_schema.column_schemas.items():
-            if col_schema.is_list and col_schema.is_ragged:
-                self._set_output_dtype(col_name + "__offsets")
-                self._set_output_dtype(col_name + "__values")
-            else:
-                self._set_output_dtype(col_name)
+        self.cats = mc_cats or schema_cats
+        self.conts = mc_conts or schema_conts

-        self.runner = WorkflowRunner(
-            self.workflow, self.output_dtypes, self.model_config, model_device
-        )
+        missing_cols = set(self.cats + self.conts) - set(self.workflow.output_schema.column_names)

-    def _set_output_dtype(self, name):
-        conf = pb_utils.get_output_config_by_name(self.model_config, name)
-        self.output_dtypes[name] = pb_utils.triton_string_to_numpy(conf["data_type"])
+        if missing_cols:
+            raise ValueError(
+                "The following requested columns were not found in the workflow's output: "
+                f"{missing_cols}"
+            )

     @triton_multi_request
     @triton_error_handling
@@ -98,28 +102,37 @@ def execute(self, request):
         """Transforms the input batches by running through an NVTabular
         workflow.transform function.
""" - # transform the triton tensors to a dict of name:numpy tensor - input_tensors = { - name: _convert_tensor(pb_utils.get_input_tensor_by_name(request, name)) - for name in self.input_dtypes - } - - # multihots are represented as a tuple of (values, offsets) - for name, dtype in self.input_multihots.items(): - values = _convert_tensor(pb_utils.get_input_tensor_by_name(request, name + "__values")) - offsets = _convert_tensor( - pb_utils.get_input_tensor_by_name(request, name + "__offsets") - ) - input_tensors[name] = (values, offsets) - transformed = self.runner.run_workflow(input_tensors) - result = [pb_utils.Tensor(name, data) for name, data in transformed.items()] + try: + input_columns = self.workflow.input_schema.column_names + input_tensors = triton_request_to_tensor_table(request, input_columns) + transformed = self.runtime.transform(self.workflow.graph, input_tensors) + return tensor_table_to_triton_response(TensorTable(transformed)) + except BaseException as e: + import traceback + + raise RuntimeError( + f"Error: {type(e)} - {str(e)}, " + f"Traceback: {traceback.format_tb(e.__traceback__)}" + ) from e + + +def _parse_schema_features(schema): + schema_cats = schema.apply(ColumnSelector(tags=[Tags.CATEGORICAL])).column_names + schema_conts = schema.apply(ColumnSelector(tags=[Tags.CONTINUOUS])).column_names + + return schema_cats, schema_conts + - return pb_utils.InferenceResponse(result) +def _parse_mc_features(model_config): + mc_cats = json.loads(_get_param(model_config, "cats", "string_value", default="[]")) + mc_conts = json.loads(_get_param(model_config, "conts", "string_value", default="[]")) + return mc_cats, mc_conts -def _parse_input_dtypes(dtypes): - input_dtypes = {col: dtype for col, dtype in dtypes if not is_list_dtype(dtype)} - input_multihots = {col: dtype for col, dtype in dtypes if is_list_dtype(dtype)} - return input_dtypes, input_multihots +def _get_param(config, *args, default=None): + config_element = config["parameters"] + for key in args: + config_element = config_element.get(key, {}) + return config_element or default diff --git a/merlin/systems/workflow/__init__.py b/merlin/systems/workflow/__init__.py deleted file mode 100644 index 7abce527d..000000000 --- a/merlin/systems/workflow/__init__.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from merlin.schema import Tags - - -def get_embedding_sizes(source, output_dtypes=None): - """Returns a dictionary of embedding sizes from a workflow or workflow_node - - Parameters - ---------- - source : Workflow or ColumnSelector - Either a nvtabular Workflow or ColumnSelector object that we should use to find - embedding sizes - output_dtypes : dict, optional - Optional dictionary of column_name:dtype. If passing a workflow object dtypes - will be read from the workflow. This is used to figure out which columns - are multihot-categorical, which are split out by this function. 
-        If passed a workflow_node
-        and this parameter isn't set, you won't have multihot columns returned separately
-    """
-    # TODO: do we need to distinguish multihot columns here? (if so why? )
-
-    # have to lazy import Workflow to avoid circular import errors
-    from nvtabular.workflow import Workflow
-
-    output_node = source.output_node if isinstance(source, Workflow) else source
-
-    if isinstance(source, Workflow):
-        output_dtypes = output_dtypes or source.output_dtypes
-    else:
-        # passed in a column group
-        output_dtypes = output_dtypes or {}
-
-    output = {}
-    multihot_columns = set()
-    cats_schema = output_node.output_schema.select_by_tag(Tags.CATEGORICAL)
-    for col_name, col_schema in cats_schema.column_schemas.items():
-        if col_schema.dtype and col_schema.is_list and col_schema.is_ragged:
-            # multi hot so remove from output and add to multihot
-            multihot_columns.add(col_name)
-
-        embeddings_sizes = col_schema.properties.get("embedding_sizes", {})
-        cardinality = embeddings_sizes["cardinality"]
-        dimensions = embeddings_sizes["dimension"]
-        output[col_name] = (cardinality, dimensions)
-
-    # TODO: returning different return types like this (based off the presence
-    # of multihot features) is pretty janky. fix.
-    if not multihot_columns:
-        return output
-
-    single_hots = {k: v for k, v in output.items() if k not in multihot_columns}
-    multi_hots = {k: v for k, v in output.items() if k in multihot_columns}
-
-    return single_hots, multi_hots
diff --git a/merlin/systems/workflow/base.py b/merlin/systems/workflow/base.py
deleted file mode 100644
index c3aeba11c..000000000
--- a/merlin/systems/workflow/base.py
+++ /dev/null
@@ -1,191 +0,0 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-#  * Redistributions of source code must retain the above copyright
-#    notice, this list of conditions and the following disclaimer.
-#  * Redistributions in binary form must reproduce the above copyright
-#    notice, this list of conditions and the following disclaimer in the
-#    documentation and/or other materials provided with the distribution.
-#  * Neither the name of NVIDIA CORPORATION nor the names of its
-#    contributors may be used to endorse or promote products derived
-#    from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-import functools
-import json
-import logging
-
-from merlin.core.dispatch import concat_columns
-from merlin.dag import ColumnSelector, Supports
-from merlin.schema import Tags
-from merlin.systems.triton.conversions import convert_format
-from merlin.table import TensorTable
-
-LOG = logging.getLogger("merlin-systems")
-
-
-class WorkflowRunner:
-    def __init__(self, workflow, output_dtypes, model_config, model_device):
-        self.workflow = workflow
-        self.output_dtypes = output_dtypes
-        self.model_config = model_config
-        self.device = model_device
-
-        output_schema = self.workflow.output_schema
-
-        schema_cats = output_schema.apply(ColumnSelector(tags=[Tags.CATEGORICAL])).column_names
-        schema_conts = output_schema.apply(ColumnSelector(tags=[Tags.CONTINUOUS])).column_names
-
-        mc_cats = json.loads(self._get_param(model_config, "cats", "string_value", default="[]"))
-        mc_conts = json.loads(self._get_param(model_config, "conts", "string_value", default="[]"))
-
-        self.cats = mc_cats or schema_cats
-        self.conts = mc_conts or schema_conts
-        self.offsets = None
-
-        workflow_outputs = set(workflow.output_schema.column_names)
-        requested_cols = set(self.cats + self.conts)
-        missing_cols = requested_cols - workflow_outputs
-
-        if missing_cols:
-            raise ValueError(
-                f"The following columns were not found in the workflow's output: {missing_cols}"
-            )
-
-        # recurse over all column groups, initializing operators for inference pipeline
-        self._initialize_ops(self.workflow.output_node)
-
-    def _initialize_ops(self, workflow_node, visited=None):
-        if visited is None:
-            visited = set()
-
-        if workflow_node.op and hasattr(workflow_node.op, "inference_initialize"):
-            inference_op = workflow_node.op.inference_initialize(
-                workflow_node.selector, self.model_config
-            )
-            if inference_op:
-                workflow_node.op = inference_op
-
-            supported = workflow_node.op.supports
-
-            # if we're running on the CPU only, mask off support for GPU data formats
-            if self.device == "CPU":
-                supported = functools.reduce(
-                    lambda a, b: a | b,
-                    (v for v in list(Supports) if v & supported and "CPU" in str(v)),
-                )
-            # the 'supports' property is readonly, and we can't always attach a new property
-            # to some of the operators (C++ categorify etc).
-            # set on the workflow_node instead
-            workflow_node.inference_supports = supported
-
-        for parent in workflow_node.parents_with_dependencies:
-            if parent not in visited:
-                visited.add(parent)
-                self._initialize_ops(parent, visited)
-
-    def run_workflow(self, input_tensors):
-        # use our NVTabular workflow to transform the dataset
-        transformed, kind = self._transform_tensors(input_tensors, self.workflow.output_node)
-
-        # if we don't have tensors in numpy format, convert back so that the we can return
-        # to triton
-        if kind != Supports.CPU_DICT_ARRAY:
-            transformed, kind = convert_format(transformed, kind, Supports.CPU_DICT_ARRAY)
-
-        return TensorTable(transformed).to_dict()
-
-    def _transform_tensors(self, input_tensors, workflow_node):
-        upstream_inputs = []
-
-        # Gather inputs from the parents and dependency nodes
-        if workflow_node.parents_with_dependencies:
-            for parent in workflow_node.parents_with_dependencies:
-                upstream_tensors, upstream_kind = self._transform_tensors(input_tensors, parent)
-                if upstream_tensors is not None and upstream_kind:
-                    upstream_inputs.append((upstream_tensors, upstream_kind))
-
-        # Gather additional input columns from the original input tensors
-        if workflow_node.selector:
-            selector_columns = workflow_node.selector.names
-            to_remove = []
-            for upstream_tensors, upstream_kind in upstream_inputs:
-                for col in selector_columns:
-                    if col in upstream_tensors:
-                        to_remove.append(col)
-            for col in set(to_remove):
-                selector_columns.remove(col)
-
-            if selector_columns:
-                selected_tensors = {c: input_tensors[c] for c in selector_columns}
-                selected_kinds = Supports.CPU_DICT_ARRAY
-                upstream_inputs.append((selected_tensors, selected_kinds))
-
-        # Standardize the formats
-        tensors, kind = None, None
-        for upstream_tensors, upstream_kind in upstream_inputs:
-            if tensors is None:
-                tensors, kind = upstream_tensors, upstream_kind
-            else:
-                if kind != upstream_kind:
-                    # we have multiple different kinds of data here (dataframe/array on cpu/gpu)
-                    # we need to convert to a common format here first before concatenating.
-                    op = workflow_node.op
-                    if op and hasattr(op, "inference_supports"):
-                        target_kind = op.inference_supports
-                    else:
-                        target_kind = Supports.CPU_DICT_ARRAY
-                    # note : the 2nd convert_format call needs to be stricter in what the kind is
-                    # (exact match rather than a bitmask of values)
-                    tensors, kind = convert_format(tensors, kind, target_kind)
-                    upstream_tensors, _ = convert_format(upstream_tensors, upstream_kind, kind)
-
-                tensors = self.concat_tensors([tensors, upstream_tensors], kind)
-
-        # Run the transform
-        if tensors is not None and kind and workflow_node.op:
-            try:
-                # if the op doesn't support the current kind - we need to convert
-                if (
-                    hasattr(workflow_node, "inference_supports")
-                    and not workflow_node.inference_supports & kind
-                ):
-                    tensors, kind = convert_format(tensors, kind, workflow_node.inference_supports)
-
-                tensors = workflow_node.op.transform(
-                    workflow_node.input_columns,
-                    tensors,
-                )
-
-            except Exception:
-                LOG.exception("Failed to transform operator %s", workflow_node.op)
-                raise
-
-        return tensors, kind
-
-    def concat_tensors(self, tensors, kind):
-        if kind & (Supports.GPU_DATAFRAME | Supports.CPU_DATAFRAME):
-            return concat_columns(tensors)
-        else:
-            output = tensors[0]
-            for tensor in tensors[1:]:
-                output.update(tensor)
-            return output
-
-    def _get_param(self, config, *args, default=None):
-        config_element = config["parameters"]
-        for key in args:
-            config_element = config_element.get(key, {})
-        return config_element or default
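Finally, a standalone check of the config helpers that replace `WorkflowRunner`'s parameter parsing: `_get_param` walks nested keys in a parsed Triton model config and falls back to a default, and the "cats"/"conts" values are JSON-encoded lists. Exercising the same logic outside Triton (the config dict below is a made-up example):

```python
import json


def _get_param(config, *args, default=None):
    # Verbatim from the new workflow_model.py above
    config_element = config["parameters"]
    for key in args:
        config_element = config_element.get(key, {})
    return config_element or default


model_config = {"parameters": {"cats": {"string_value": '["item_id", "category"]'}}}

cats = json.loads(_get_param(model_config, "cats", "string_value", default="[]"))
conts = json.loads(_get_param(model_config, "conts", "string_value", default="[]"))
print(cats, conts)  # ['item_id', 'category'] []
```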