Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Triton backend support #537

Merged
merged 21 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/examples/pre-process.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@ ensure that any IO does not prevent utilization of the compute on your pod.
To do this, you can use the pre/post process methods on a Truss. These methods
can be defined like this:


```python
class Model:
    def __init__(self): ...
def load(self, **kwargs) -> None: ...
def preprocess(self, request):
# Include any IO logic that happens _before_ predict here
...

def predict(self, request):
# Include the actual predict here
...
Expand Down Expand Up @@ -120,4 +121,3 @@ resources:
```

</RequestExample>

3 changes: 2 additions & 1 deletion truss/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ def init(
target_directory_path = populate_target_directory(
config=config,
target_directory_path=target_directory,
populate_dirs=config.build.model_server is ModelServer.TrussServer,
populate_dirs=config.build.model_server
in [ModelServer.TrussServer, ModelServer.TRITON],
)

if trainable:
Expand Down
1 change: 1 addition & 0 deletions truss/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

TEMPLATES_DIR = pathlib.Path(CODE_DIR, "templates")
SERVER_CODE_DIR: pathlib.Path = TEMPLATES_DIR / "server"
TRITON_SERVER_CODE_DIR: pathlib.Path = TEMPLATES_DIR / "triton"
TRAINING_JOB_WRAPPER_CODE_DIR_NAME = "training"
TRAINING_JOB_WRAPPER_CODE_DIR: pathlib.Path = (
TEMPLATES_DIR / TRAINING_JOB_WRAPPER_CODE_DIR_NAME
Expand Down
44 changes: 44 additions & 0 deletions truss/contexts/image_builder/serving_image_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME,
SYSTEM_PACKAGES_TXT_FILENAME,
TEMPLATES_DIR,
TRITON_SERVER_CODE_DIR,
)
from truss.contexts.image_builder.image_builder import ImageBuilder
from truss.contexts.image_builder.util import (
Expand Down Expand Up @@ -52,6 +53,46 @@
HF_ACCESS_TOKEN_SECRET_NAME = "hf_access_token"


def create_triton_build_dir(config: TrussConfig, build_dir: Path, truss_dir: Path):
    """Assemble a Triton-compatible model repository inside ``build_dir``.

    Copies the Triton server template, the user's truss, and the shared
    serving/training code into the directory layout the Triton server
    expects, then materializes the resolved config, any external data, and
    the requirements / system-packages files.

    Layout produced (names are significant at invocation time):
        build_dir
        ├── model                  # model name, used during invocation
        │   └── 1                  # model version, used during invocation
        │       └── truss          # user-defined truss (config, model code, ...)
        └── ...                    # Triton server root files

    Args:
        config: Parsed truss configuration for the model being packaged.
        build_dir: Destination directory for the image build context.
        truss_dir: Source directory of the user's truss.
    """
    spec = TrussSpec(truss_dir)
    build_dir.mkdir(parents=True, exist_ok=True)

    model_repository_path = build_dir / "model"
    user_truss_path = model_repository_path / "truss"  # type: ignore[operator]
    data_dir = model_repository_path / config.data_dir  # type: ignore[operator]

    # Server template first, then the user's code and the shared helpers.
    copy_tree_path(TRITON_SERVER_CODE_DIR / "model", model_repository_path)
    copy_tree_path(TRITON_SERVER_CODE_DIR / "root", build_dir)
    copy_tree_path(truss_dir, user_truss_path)
    copy_tree_path(
        SHARED_SERVING_AND_TRAINING_CODE_DIR,
        model_repository_path / SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME,
    )

    # Re-write the copied config so the server sees the fully-resolved
    # (verbose) form rather than the user's sparse one.
    with (model_repository_path / "truss" / CONFIG_FILE).open("w") as config_file:
        yaml.dump(config.to_dict(verbose=True), config_file)

    # Fetch any externally-hosted data into the repository's data directory.
    download_external_data(spec.external_data, data_dir)

    (build_dir / REQUIREMENTS_TXT_FILENAME).write_text(spec.requirements_txt)
    (build_dir / SYSTEM_PACKAGES_TXT_FILENAME).write_text(spec.system_packages_txt)


def split_gs_path(gs_path):
# Remove the 'gs://' prefix
path = gs_path.replace("gs://", "")
Expand Down Expand Up @@ -267,6 +308,9 @@ def prepare_image_build_dir(
elif config.build.model_server is ModelServer.VLLM:
create_vllm_build_dir(config, build_dir, truss_dir)
return
elif config.build.model_server is ModelServer.TRITON:
create_triton_build_dir(config, build_dir, truss_dir)
return

data_dir = build_dir / config.data_dir # type: ignore[operator]

Expand Down
58 changes: 58 additions & 0 deletions truss/templates/triton/model/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from pathlib import Path

import yaml
from triton_model_wrapper import TritonModelWrapper


class TritonPythonModel:
    """Entry point that Triton's Python backend loads for this model.

    Triton discovers this class by name. It delegates all real work to
    ``TritonModelWrapper``, which imports and drives the user's truss model.
    """

    def __init__(self):
        # All state is populated lazily in initialize(), which Triton calls
        # once at server startup; until then every attribute is None.
        self._model_repository_path = None
        self._config_path = None
        self._config = None
        self._model_wrapper = None

    def _instantiate_model_wrapper(self, triton_config: dict) -> None:
        """
        Instantiates the model wrapper class as well as the user-defined model class.

        Args:
            triton_config (dict): Triton configuration dictionary. This contains information about the model
                repository path and the model version.

        Returns:
            None
        """
        self._model_repository_path = (
            Path(triton_config["model_repository"]) / triton_config["model_version"]
        )
        # The truss config is copied next to the user's code at build time.
        self._config_path = self._model_repository_path / "truss" / "config.yaml"
        with open(self._config_path, encoding="utf-8") as config_file:
            self._config = yaml.safe_load(config_file)
        self._model_wrapper = TritonModelWrapper(self._config)
        self._model_wrapper.instantiate()

    def initialize(self, triton_config: dict) -> None:
        """
        Instantiates the model wrapper class and loads the user's model. This function is called by Triton upon startup.

        Args:
            triton_config (dict): Triton configuration dictionary. This contains information about the model
                repository path and the model version as well as all the information defined in the config.pbtxt.

        Returns:
            None
        """
        self._instantiate_model_wrapper(triton_config)
        self._model_wrapper.load()

    def execute(self, requests: list) -> list:
        """
        Executes the user's model on Triton InferenceRequest objects. This function is called by Triton upon inference.

        Args:
            requests (list): Triton InferenceRequest objects containing the input data for the model.

        Returns:
            list: Triton InferenceResponse objects containing the output data from the model.
        """
        return self._model_wrapper.predict(requests)
102 changes: 102 additions & 0 deletions truss/templates/triton/model/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import json
from typing import List, Type

import numpy as np
from pydantic import BaseModel
from utils.errors import InvalidModelResponseError
from utils.pydantic import PYTHON_TYPE_TO_NP_DTYPE, inspect_pydantic_model
from utils.triton import (
InferenceRequest,
InferenceResponse,
convert_tensor_to_python_type,
create_inference_response,
create_tensor,
get_input_tensor_by_name,
)


def transform_triton_to_pydantic(
    triton_requests: List[InferenceRequest],
    pydantic_type: Type[BaseModel],
) -> List[BaseModel]:
    """Convert Triton requests into instances of ``pydantic_type``.

    For each request, every field declared on the Pydantic model is looked
    up as an input tensor by name, converted to a plain Python value, and
    the collected values are used to construct one Pydantic object per
    request.

    Args:
        triton_requests: Triton InferenceRequest objects to convert.
        pydantic_type: The Pydantic model class to instantiate.

    Returns:
        One ``pydantic_type`` instance per incoming request.

    Raises:
        ValueError: If a tensor for one of the model's fields cannot be
            found in a Triton request.
        TypeError: If a tensor has an unsupported type.
    """
    fields = inspect_pydantic_model(pydantic_type)

    def build(request):
        # Each inspected field yields its (name, type) pair via .values().
        kwargs = {}
        for field in fields:
            name, annotation = field.values()
            tensor = get_input_tensor_by_name(request, name)
            kwargs[name] = convert_tensor_to_python_type(tensor, annotation)
        return pydantic_type(**kwargs)

    return [build(request) for request in triton_requests]


def transform_pydantic_to_triton(
    pydantic_objects: List[BaseModel],
) -> List[InferenceResponse]:
    """
    Transforms a list of Pydantic objects into a list of Triton inference responses.

    This function iterates over the fields of each Pydantic object, determines the
    appropriate tensor data type, and creates a tensor for each field. These tensors
    are then used to create Triton inference responses.

    Args:
        pydantic_objects (list): A list of Pydantic objects to be transformed.

    Returns:
        list: A list of Triton inference responses created from the Pydantic objects.

    Raises:
        InvalidModelResponseError: If the model did not return a list of Pydantic objects.
        ValueError: If a Python type in the Pydantic object does not have a corresponding numpy dtype.
    """
    results = []

    # Validate the overall shape of the model's response before converting
    # anything: it must be a list of Pydantic models, even for batch size 1.
    if not isinstance(pydantic_objects, list):
        raise InvalidModelResponseError(
            "Truss did not return a list. Please ensure that your model returns a list \
            of Pydantic objects even if the batch size is 1."
        )

    if not all(isinstance(obj, BaseModel) for obj in pydantic_objects):
        raise InvalidModelResponseError(
            "Truss returned a list of objects that are not Pydantic models. Please \
            ensure that your model returns a list of Pydantic objects even if the \
            batch size is 1."
        )

    for obj in pydantic_objects:
        tensors = []
        for field, value in obj.dict().items():
            if isinstance(value, list):
                # Dtype is inferred from the first element; all elements are
                # assumed homogeneous. NOTE(review): an empty list raises
                # IndexError here — confirm whether empty fields are legal.
                dtype = PYTHON_TYPE_TO_NP_DTYPE[type(value[0])]
            elif isinstance(value, dict):
                # Dicts are serialized to a JSON string and shipped as a
                # single-element tensor. NOTE(review): the dtype lookup uses
                # type(value) (i.e. dict) — presumably the mapping has an
                # entry for dict; verify against PYTHON_TYPE_TO_NP_DTYPE.
                dtype = PYTHON_TYPE_TO_NP_DTYPE[type(value)]
                value = [json.dumps(value)]
            else:
                # Scalars become single-element tensors of the mapped dtype.
                dtype = PYTHON_TYPE_TO_NP_DTYPE[type(value)]
                value = [value]
            tensors.append(create_tensor(field, np.array(value, dtype=dtype)))
        results.append(create_inference_response(tensors))

    return results
81 changes: 81 additions & 0 deletions truss/templates/triton/model/triton_model_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import importlib
import inspect
import sys
from pathlib import Path
from typing import Dict

import triton_python_backend_utils as pb_utils
from shared.secrets_resolver import SecretsResolver
from transform import transform_pydantic_to_triton, transform_triton_to_pydantic
from utils.errors import MissingInputClassError, MissingOutputClassError
from utils.signature import signature_accepts_keyword_arg


class TritonModelWrapper:
    """Bridges Triton's Python backend and a user-defined truss model.

    Locates and imports the user's model class from the copied truss, wires
    up the optional init arguments it accepts (config, data dir, secrets),
    and translates between Triton request/response objects and the model's
    Pydantic input/output types.
    """

    def __init__(self, config: Dict):
        self._config = config
        self._model = None
        # Pydantic input/output classes discovered on the user's model
        # module; populated by instantiate().
        self._input_type = None
        self._output_type = None

    def instantiate(self) -> None:
        """
        Instantiates the model class defined in the config.

        Raises:
            MissingInputClassError: If the user's module defines no input type.
            MissingOutputClassError: If the user's module defines no output type.
        """
        # Triton mounts the model repository at this fixed path; the user's
        # truss is copied under it at image-build time.
        model_directory_prefix = Path("/app/model/1/truss")
        data_dir = model_directory_prefix / "data"

        if "bundled_packages_dir" in self._config:
            bundled_packages_path = model_directory_prefix / "bundled_packages"
            if bundled_packages_path.exists():
                sys.path.append(str(bundled_packages_path))
        model_module_name = str(
            Path(self._config["model_class_filename"]).with_suffix("")
        )
        module = importlib.import_module(
            f"truss.{self._config['model_module_dir']}.{model_module_name}"
        )
        model_class = getattr(module, self._config["model_class_name"])

        # The input/output class names default to "Input"/"Output" but can
        # be overridden via the config.
        input_type_name = self._config.get("input_type_name", "Input")
        output_type_name = self._config.get("output_type_name", "Output")
        self._input_type = getattr(module, input_type_name, None)
        self._output_type = getattr(module, output_type_name, None)

        if not self._input_type:
            raise MissingInputClassError("Input type not defined in user model")
        if not self._output_type:
            raise MissingOutputClassError("Output type not defined in user model")

        # Only pass init kwargs that the user's __init__ actually accepts.
        model_class_signature = inspect.signature(model_class)
        model_init_params = {}
        if signature_accepts_keyword_arg(model_class_signature, "config"):
            model_init_params["config"] = self._config
        if signature_accepts_keyword_arg(model_class_signature, "data_dir"):
            model_init_params["data_dir"] = data_dir  # type: ignore
        if signature_accepts_keyword_arg(model_class_signature, "secrets"):
            model_init_params["secrets"] = SecretsResolver.get_secrets(self._config)

        self._model = model_class(**model_init_params)

    def load(self) -> None:
        """
        Invokes the underlying model's load method.

        Raises:
            ValueError: If instantiate() has not been called yet.
        """
        if self._model is None:
            raise ValueError("Model not instantiated")

        self._model.load()

    def predict(self, requests: list) -> list:
        """
        Converts Triton requests to Pydantic objects, invokes the underlying model's predict method, and converts the
        outputs to Triton objects.
        """
        requests = transform_triton_to_pydantic(requests, self._input_type)
        try:
            outputs = self._model.predict(requests)  # type: ignore
        except Exception as e:
            # Chain the original exception so its traceback is preserved in
            # Triton's logs instead of being replaced by this raise site.
            raise pb_utils.TritonModelException(e) from e

        return transform_pydantic_to_triton(outputs)
37 changes: 37 additions & 0 deletions truss/templates/triton/model/utils/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
class Error(Exception):
    """Base Truss Error.

    Attributes:
        message: Human-readable description of the failure.
        type: Name of the concrete error class, for structured reporting.
    """

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message
        # Fix: this previously stored the `type` builtin itself, which
        # carries no information. Store the concrete class name instead.
        self.type = type(self).__name__


class UnsupportedTypeError(Error):
    """Raised when a Pydantic field type is not supported."""


class MissingFieldError(Error):
    """Raised when a Pydantic field is missing."""


class MissingInputClassError(Error):
    """Raised when the user does not define an input class."""


class MissingOutputClassError(Error):
    """Raised when the user does not define an output class."""


class InvalidModelResponseError(Error):
    """Raised when the model returns an invalid response."""
Loading