Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

changing the default runtime from Ensemble to Executor #255

Merged
merged 7 commits into from
Dec 16, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions merlin/systems/dag/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from merlin.core.protocols import Transformable
from merlin.dag import Graph
from merlin.systems.dag.runtimes.triton import TritonEnsembleRuntime
from merlin.systems.dag.runtimes.triton import TritonExecutorRuntime


class Ensemble:
Expand Down Expand Up @@ -74,7 +74,7 @@ def transform(self, transformable: Transformable, runtime=None):
Transformable
transformed data
"""
runtime = runtime or TritonEnsembleRuntime()
runtime = runtime or TritonExecutorRuntime()
return runtime.transform(self.graph, transformable)

def save(self, path):
Expand Down Expand Up @@ -149,5 +149,5 @@ def export(self, export_path, runtime=None, **kwargs):
Write out an ensemble model configuration directory. The exported
ensemble is designed for use with Triton Inference Server.
"""
runtime = runtime or TritonEnsembleRuntime()
runtime = runtime or TritonExecutorRuntime()
return runtime.export(self, export_path, **kwargs)
5 changes: 1 addition & 4 deletions merlin/systems/dag/runtimes/triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,4 @@
# limitations under the License.
#
# flake8: noqa
from merlin.systems.dag.runtimes.triton.runtime import ( # noqa
TritonEnsembleRuntime,
TritonExecutorRuntime,
)
from merlin.systems.dag.runtimes.triton.runtime import TritonExecutorRuntime # noqa
159 changes: 1 addition & 158 deletions merlin/systems/dag/runtimes/triton/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
import tritonclient.grpc.model_config_pb2 as model_config
from google.protobuf import text_format

from merlin.core.protocols import Transformable
from merlin.dag import Graph, postorder_iter_nodes
from merlin.dag import postorder_iter_nodes
from merlin.systems.dag.ops import compute_dims
from merlin.systems.dag.ops.compat import (
cuml_ensemble,
Expand Down Expand Up @@ -76,162 +75,6 @@
TRITON_OP_TABLE[PredictPyTorch] = PredictPyTorchTriton


class TritonEnsembleRuntime(Runtime):
    """Runtime for Triton. Runs each operator in DAG as a separate model in a Triton Ensemble."""

    def __init__(self):
        super().__init__()
        # Lookup table mapping Merlin operator types to Triton-specific
        # replacement ops (module-level TRITON_OP_TABLE).
        self.op_table = TRITON_OP_TABLE

    def transform(self, graph: Graph, transformable: Transformable):
        """Local transform is unsupported for this runtime; with the ensemble
        runtime, DAG execution happens inside Triton Inference Server itself.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("Transform handled by Triton")

    def export(
        self, ensemble, path: str, version: int = 1, name: str = None
    ) -> Tuple[model_config.ModelConfig, List[model_config.ModelConfig]]:
        """Exports an 'Ensemble' as a triton model repository.

        Every operator is represented as a separate model,
        loaded individually in Triton.

        The entry point is the ensemble model with the name `name`, by default "ensemble_model"

        Parameters
        ----------
        ensemble : merlin.systems.dag.Ensemble
            Systems ensemble to export
        path : str
            Path to directory where Triton model repository will be created.
        version : int, optional
            Version for Triton models created, by default 1
        name : str, optional
            The name of the ensemble triton model, by default "ensemble_model"

        Returns
        -------
        Tuple[model_config.ModelConfig, List[model_config.ModelConfig]]
            Tuple of ensemble config and list of non-python backend model configs
        """
        name = name or "ensemble_model"
        # Build node id lookup table

        # Postorder walk so every node appears after its dependencies.
        nodes = list(postorder_iter_nodes(ensemble.graph.output_node))

        # Swap each supported operator for its Triton-aware counterpart
        # before export.
        for node in nodes:
            if type(node.op) in self.op_table:
                node.op = self.op_table[type(node.op)](node.op)

        node_id_table, num_nodes = _create_node_table(nodes, "ensemble")

        nodes = nodes or []
        node_id_table = node_id_table or {}

        # Create ensemble config
        ensemble_config = model_config.ModelConfig(
            name=name,
            platform="ensemble",
            # max_batch_size=configs[0].max_batch_size
        )

        # Declare the ensemble's external inputs/outputs from the graph schemas.
        for _, col_schema in ensemble.graph.input_schema.column_schemas.items():
            add_model_param(
                ensemble_config.input,
                model_config.ModelInput,
                col_schema,
                compute_dims(col_schema),
            )

        for _, col_schema in ensemble.graph.output_schema.column_schemas.items():
            add_model_param(
                ensemble_config.output,
                model_config.ModelOutput,
                col_schema,
                compute_dims(col_schema),
            )

        node_configs = []
        for node in nodes:
            if node.exportable("ensemble"):
                node_id = node_id_table.get(node, None)
                node_name = f"{node_id}_{node.export_name}"

                # Skip nodes that were already added as a scheduling step.
                found = False
                for step in ensemble_config.ensemble_scheduling.step:
                    if step.model_name == node_name:
                        found = True
                if found:
                    continue

                # Export the per-operator model; may return None for nodes
                # that produce no standalone config.
                node_config = node.export(
                    path, node_id=node_id, version=version, backend="ensemble"
                )
                if node_config is not None:
                    node_configs.append(node_config)

                # model_version=-1 tells Triton to use the latest version.
                config_step = model_config.ModelEnsembling.Step(
                    model_name=node_name, model_version=-1
                )

                # Wire this step's inputs to the outputs of whichever upstream
                # node produced each column (id suffix disambiguates tensors).
                for input_col_name, input_col_schema in node.input_schema.column_schemas.items():
                    source = self._find_column_source(
                        node.parents_with_dependencies, input_col_name, "ensemble"
                    )
                    source_id = node_id_table.get(source, None)
                    in_suffix = f"_{source_id}" if source_id is not None else ""

                    # Ragged list columns are carried as a values/lengths
                    # tensor pair rather than a single tensor.
                    if input_col_schema.is_list and input_col_schema.is_ragged:
                        config_step.input_map[input_col_name + "__values"] = (
                            input_col_name + "__values" + in_suffix
                        )
                        config_step.input_map[input_col_name + "__lengths"] = (
                            input_col_name + "__lengths" + in_suffix
                        )
                    else:
                        config_step.input_map[input_col_name] = input_col_name + in_suffix

                for output_col_name, output_col_schema in node.output_schema.column_schemas.items():
                    # Only non-final nodes get an id suffix; the last node's
                    # outputs keep the ensemble's external output names.
                    out_suffix = (
                        f"_{node_id}" if node_id is not None and node_id < num_nodes - 1 else ""
                    )

                    if output_col_schema.is_list and output_col_schema.is_ragged:
                        config_step.output_map[output_col_name + "__values"] = (
                            output_col_name + "__values" + out_suffix
                        )
                        config_step.output_map[output_col_name + "__lengths"] = (
                            output_col_name + "__lengths" + out_suffix
                        )
                    else:
                        config_step.output_map[output_col_name] = output_col_name + out_suffix

                ensemble_config.ensemble_scheduling.step.append(config_step)

        # Write the ensemble config file
        ensemble_path = os.path.join(path, name)
        os.makedirs(ensemble_path, exist_ok=True)
        os.makedirs(os.path.join(ensemble_path, str(version)), exist_ok=True)

        config_path = os.path.join(ensemble_path, "config.pbtxt")
        with open(config_path, "w", encoding="utf-8") as o:
            text_format.PrintMessage(ensemble_config, o)

        return (ensemble_config, node_configs)

    def _find_column_source(self, upstream_nodes, column_name, backend):
        """Find the nearest exportable upstream node that produces `column_name`.

        Walks `upstream_nodes`; if the first producer found is not exportable
        for `backend`, recurses through that producer's own parents. Returns
        the producing node, or None if no producer exists.
        """
        source_node = None
        for upstream_node in upstream_nodes:
            if column_name in upstream_node.output_columns.names:
                source_node = upstream_node
                break

        if source_node and not source_node.exportable(backend):
            # Producer can't be exported in this backend — keep searching
            # past it through its parents and dependencies.
            return self._find_column_source(
                source_node.parents_with_dependencies, column_name, backend
            )
        else:
            return source_node


class TritonExecutorRuntime(Runtime):
"""Runtime for Triton.
This will run the DAG in a single Triton model and call out to other
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from merlin.schema import ColumnSchema, Schema
from merlin.systems.dag.ensemble import Ensemble
from merlin.systems.dag.ops.fil import PredictForest
from merlin.systems.dag.runtimes.triton import TritonEnsembleRuntime, TritonExecutorRuntime
from merlin.systems.dag.runtimes.triton import TritonExecutorRuntime
from merlin.systems.triton.utils import run_ensemble_on_tritonserver

sklearn_datasets = pytest.importorskip("sklearn.datasets")
Expand All @@ -23,7 +23,6 @@
@pytest.mark.parametrize(
["runtime", "model_name", "expected_model_name"],
[
(TritonEnsembleRuntime(), None, "ensemble_model"),
(TritonExecutorRuntime(), None, "executor_model"),
],
)
Expand Down Expand Up @@ -68,7 +67,6 @@ def test_lightgbm_regressor_forest_inference(runtime, model_name, expected_model
@pytest.mark.parametrize(
["runtime", "model_name", "expected_model_name"],
[
(TritonEnsembleRuntime(), None, "ensemble_model"),
(TritonExecutorRuntime(), None, "executor_model"),
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from merlin.schema import ColumnSchema, Schema
from merlin.systems.dag.ensemble import Ensemble
from merlin.systems.dag.ops.fil import PredictForest
from merlin.systems.dag.runtimes.triton import TritonEnsembleRuntime, TritonExecutorRuntime
from merlin.systems.dag.runtimes.triton import TritonExecutorRuntime
from merlin.systems.triton.utils import run_ensemble_on_tritonserver

sklearn_datasets = pytest.importorskip("sklearn.datasets")
Expand All @@ -23,7 +23,6 @@
@pytest.mark.parametrize(
["runtime", "model_name", "expected_model_name"],
[
(TritonEnsembleRuntime(), None, "ensemble_model"),
(TritonExecutorRuntime(), None, "executor_model"),
],
)
Expand Down Expand Up @@ -69,7 +68,6 @@ def test_sklearn_regressor_forest_inference(runtime, model_name, expected_model_
@pytest.mark.parametrize(
["runtime", "model_name", "expected_model_name"],
[
(TritonEnsembleRuntime(), None, "ensemble_model"),
(TritonExecutorRuntime(), None, "executor_model"),
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from merlin.schema import ColumnSchema, Schema
from merlin.systems.dag.ensemble import Ensemble
from merlin.systems.dag.ops.fil import PredictForest
from merlin.systems.dag.runtimes.triton import TritonEnsembleRuntime, TritonExecutorRuntime
from merlin.systems.dag.runtimes.triton import TritonExecutorRuntime
from merlin.systems.triton.utils import run_ensemble_on_tritonserver

sklearn_datasets = pytest.importorskip("sklearn.datasets")
Expand All @@ -23,7 +23,6 @@
@pytest.mark.parametrize(
["runtime", "model_name", "expected_model_name"],
[
(TritonEnsembleRuntime(), None, "ensemble_model"),
(TritonExecutorRuntime(), None, "executor_model"),
],
)
Expand Down Expand Up @@ -68,7 +67,6 @@ def test_xgboost_regressor_forest_inference(runtime, model_name, expected_model_
@pytest.mark.parametrize(
["runtime", "model_name", "expected_model_name"],
[
(TritonEnsembleRuntime(), None, "ensemble_model"),
(TritonExecutorRuntime(), None, "executor_model"),
],
)
Expand Down
4 changes: 1 addition & 3 deletions tests/unit/systems/dag/runtimes/triton/ops/torch/test_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from merlin.schema import ColumnSchema, Schema
from merlin.systems.dag.ensemble import Ensemble
from merlin.systems.dag.runtimes.triton import TritonEnsembleRuntime, TritonExecutorRuntime
from merlin.systems.dag.runtimes.triton import TritonExecutorRuntime
from merlin.systems.triton.utils import run_triton_server

TRITON_SERVER_PATH = find_executable("tritonserver")
Expand Down Expand Up @@ -161,7 +161,6 @@ def test_torch_backend(tmpdir):
@pytest.mark.parametrize(
["runtime", "model_name", "expected_model_name"],
[
(TritonEnsembleRuntime(), None, "ensemble_model"),
(TritonExecutorRuntime(), None, "executor_model"),
],
)
Expand Down Expand Up @@ -211,7 +210,6 @@ def test_pytorch_op_serving(
@pytest.mark.parametrize(
["runtime", "model_name", "expected_model_name"],
[
(TritonEnsembleRuntime(), None, "ensemble_model"),
(TritonExecutorRuntime(), None, "executor_model"),
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pytest
from tritonclient import grpc as grpcclient

from merlin.systems.dag.runtimes.triton import TritonEnsembleRuntime, TritonExecutorRuntime
from merlin.systems.dag.runtimes.triton import TritonExecutorRuntime
from merlin.systems.triton.utils import run_triton_server
from nvtabular import Workflow
from nvtabular import ops as wf_ops
Expand All @@ -36,7 +36,6 @@
@pytest.mark.parametrize(
["runtime", "model_name", "expected_model_name"],
[
(TritonEnsembleRuntime(), None, "ensemble_model"),
(TritonExecutorRuntime(), None, "executor_model"),
],
)
Expand Down
4 changes: 1 addition & 3 deletions tests/unit/systems/dag/runtimes/triton/test_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from merlin.systems.dag import DictArray
from merlin.systems.dag.ensemble import Ensemble
from merlin.systems.dag.ops.session_filter import FilterCandidates
from merlin.systems.dag.runtimes.triton import TritonEnsembleRuntime, TritonExecutorRuntime
from merlin.systems.dag.runtimes.triton import TritonExecutorRuntime
from merlin.systems.triton.utils import run_ensemble_on_tritonserver

triton = pytest.importorskip("merlin.systems.triton")
Expand All @@ -39,8 +39,6 @@
["runtime", "model_name", "expected_model_name"],
[
(None, None, "ensemble_model"),
(TritonEnsembleRuntime(), None, "ensemble_model"),
(TritonEnsembleRuntime(), "triton_model", "triton_model"),
(TritonExecutorRuntime(), None, "executor_model"),
(TritonExecutorRuntime(), "triton_model", "triton_model"),
],
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/systems/ops/implicit/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@
tritonclient = pytest.importorskip("tritonclient")
grpcclient = pytest.importorskip("tritonclient.grpc")

from merlin.systems.dag.runtimes.triton import TritonEnsembleRuntime, TritonExecutorRuntime # noqa
from merlin.systems.dag.runtimes.triton import TritonExecutorRuntime # noqa
from merlin.systems.triton.utils import run_ensemble_on_tritonserver # noqa


@pytest.mark.skipif(not TRITON_SERVER_PATH, reason="triton server not found")
@pytest.mark.parametrize("runtime", [None, TritonEnsembleRuntime(), TritonExecutorRuntime()])
@pytest.mark.parametrize("runtime", [None, TritonExecutorRuntime()])
def test_implicit_in_triton_executor_model(tmpdir, runtime):
model = implicit.bpr.BayesianPersonalizedRanking()
n = 100
Expand Down