Get rid of WorkflowRunner entirely by collapsing into Triton model
karlhigley committed Apr 10, 2023
1 parent 1b5ad13 commit 9789068
Showing 3 changed files with 48 additions and 154 deletions.
54 changes: 48 additions & 6 deletions merlin/systems/triton/models/workflow_model.py
@@ -25,17 +25,22 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 import json
+import logging
 import pathlib
 
 import nvtabular
+from merlin.dag import ColumnSelector
+from merlin.schema import Tags
+from merlin.systems.dag.runtimes.nvtabular.runtime import NVTabularServingRuntime
 from merlin.systems.triton.conversions import (
     tensor_table_to_triton_response,
     triton_request_to_tensor_table,
 )
 from merlin.systems.triton.utils import triton_error_handling, triton_multi_request
-from merlin.systems.workflow.base import WorkflowRunner
 from merlin.table import TensorTable
 
+LOG = logging.getLogger("merlin-systems")
+
 
 class TritonPythonModel:
     """
Expand Down Expand Up @@ -65,15 +70,31 @@ def initialize(self, args):
repository_path = repository_path.parent.parent

workflow_path = repository_path / str(args["model_version"]) / "workflow"

model_device = args["model_instance_kind"]
self.runtime = NVTabularServingRuntime(model_device)

# Workflow instantiation
self.workflow = nvtabular.Workflow.load(str(workflow_path))
workflow = nvtabular.Workflow.load(str(workflow_path))
workflow.graph = self.runtime.convert(workflow.graph)
self.workflow = workflow

# Config loading and parsing
self.model_config = json.loads(args["model_config"])
model_config = json.loads(args["model_config"])

mc_cats, mc_conts = _parse_mc_features(model_config)
schema_cats, schema_conts = _parse_schema_features(self.workflow.output_schema)

self.cats = mc_cats or schema_cats
self.conts = mc_conts or schema_conts

self.runner = WorkflowRunner(self.workflow, self.model_config, model_device)
missing_cols = set(self.cats + self.conts) - set(self.workflow.output_schema.column_names)

if missing_cols:
raise ValueError(
"The following requested columns were not found in the workflow's output: "
f"{missing_cols}"
)

@triton_multi_request
@triton_error_handling
@@ -85,12 +106,33 @@ def execute(self, request):
         try:
             input_columns = self.workflow.input_schema.column_names
             input_tensors = triton_request_to_tensor_table(request, input_columns)
-            output_tensors = self.runner.run_workflow(input_tensors)
-            return tensor_table_to_triton_response(TensorTable(output_tensors))
+            transformed = self.runtime.transform(self.workflow.graph, input_tensors)
+            return tensor_table_to_triton_response(TensorTable(transformed))
         except BaseException as e:
             import traceback
 
             raise RuntimeError(
                 f"Error: {type(e)} - {str(e)}, "
                 f"Traceback: {traceback.format_tb(e.__traceback__)}"
             ) from e
+
+
+def _parse_schema_features(schema):
+    schema_cats = schema.apply(ColumnSelector(tags=[Tags.CATEGORICAL])).column_names
+    schema_conts = schema.apply(ColumnSelector(tags=[Tags.CONTINUOUS])).column_names
+
+    return schema_cats, schema_conts
+
+
+def _parse_mc_features(model_config):
+    mc_cats = json.loads(_get_param(model_config, "cats", "string_value", default="[]"))
+    mc_conts = json.loads(_get_param(model_config, "conts", "string_value", default="[]"))
+
+    return mc_cats, mc_conts
+
+
+def _get_param(config, *args, default=None):
+    config_element = config["parameters"]
+    for key in args:
+        config_element = config_element.get(key, {})
+    return config_element or default
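
A note on the added config parsing: `_parse_mc_features` reads optional "cats" and "conts" overrides from the `parameters` section of the Triton model config, each as a JSON-encoded list stored under `string_value`. A minimal, self-contained sketch of that lookup follows; the config dict and column names are hypothetical and only illustrate the shape the helpers expect.

    import json

    # Hypothetical parsed model config; only the `parameters` block matters here,
    # and the column names are made up for illustration.
    model_config = {
        "parameters": {
            "cats": {"string_value": '["user_id", "item_id"]'},
            "conts": {"string_value": '["price"]'},
        }
    }

    def _get_param(config, *args, default=None):
        # Walk nested keys, returning the default when any level is missing.
        config_element = config["parameters"]
        for key in args:
            config_element = config_element.get(key, {})
        return config_element or default

    cats = json.loads(_get_param(model_config, "cats", "string_value", default="[]"))
    conts = json.loads(_get_param(model_config, "conts", "string_value", default="[]"))

    print(cats)   # ['user_id', 'item_id']
    print(conts)  # ['price']

If neither parameter is set, both lists fall back to "[]" and the model uses the categorical and continuous tags from the workflow's output schema instead (see `_parse_schema_features` above).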
65 changes: 0 additions & 65 deletions merlin/systems/workflow/__init__.py

This file was deleted.

83 changes: 0 additions & 83 deletions merlin/systems/workflow/base.py

This file was deleted.
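
Taken together, the Triton model now drives the NVTabular serving runtime directly rather than delegating to a WorkflowRunner: initialize() converts the workflow graph for the runtime, and execute() calls runtime.transform on it. The rough sketch below exercises the same flow outside Triton; the workflow path, the "CPU" device string, the input column name, and the dict-of-arrays TensorTable constructor are all assumptions made for illustration.

    import numpy as np
    import nvtabular
    from merlin.systems.dag.runtimes.nvtabular.runtime import NVTabularServingRuntime
    from merlin.table import TensorTable

    # Load a saved workflow and convert its graph for the serving runtime,
    # mirroring TritonPythonModel.initialize() above.
    runtime = NVTabularServingRuntime("CPU")          # device string is assumed
    workflow = nvtabular.Workflow.load("./workflow")  # hypothetical path
    workflow.graph = runtime.convert(workflow.graph)

    # Build an input table and run the transform, mirroring execute().
    # The column name "a" is illustrative; real inputs come from workflow.input_schema.
    input_table = TensorTable({"a": np.array([1.0, 2.0, 3.0])})
    transformed = runtime.transform(workflow.graph, input_table)
    print(transformed)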
