diff --git a/docs/examples/pre-process.mdx b/docs/examples/pre-process.mdx
index b9a2b376b..942208f9d 100644
--- a/docs/examples/pre-process.mdx
+++ b/docs/examples/pre-process.mdx
@@ -17,6 +17,7 @@ ensure that any IO does not prevent utilization of the compute on your pod.
 
 To do this, you can use the pre/post process methods on a Truss. These methods
 can be defined like this:
+
 ```python
 class Model:
     def __init__: ...
@@ -24,7 +25,7 @@ class Model:
     def preprocess(self, request):
         # Include any IO logic that happens _before_ predict here
         ...
-    
+
     def predict(self, request):
         # Include the actual predict here
         ...
@@ -120,4 +121,3 @@ resources:
 ```
 
 
-
diff --git a/truss/build.py b/truss/build.py
index 09440b7b6..dec9a92d6 100644
--- a/truss/build.py
+++ b/truss/build.py
@@ -244,7 +244,8 @@ def init(
     target_directory_path = populate_target_directory(
         config=config,
         target_directory_path=target_directory,
-        populate_dirs=config.build.model_server is ModelServer.TrussServer,
+        populate_dirs=config.build.model_server
+        in [ModelServer.TrussServer, ModelServer.TRITON],
     )
 
     if trainable:
diff --git a/truss/constants.py b/truss/constants.py
index c7258164b..1e55954b9 100644
--- a/truss/constants.py
+++ b/truss/constants.py
@@ -16,6 +16,7 @@
 TEMPLATES_DIR = pathlib.Path(CODE_DIR, "templates")
 SERVER_CODE_DIR: pathlib.Path = TEMPLATES_DIR / "server"
+TRITON_SERVER_CODE_DIR: pathlib.Path = TEMPLATES_DIR / "triton"
 TRAINING_JOB_WRAPPER_CODE_DIR_NAME = "training"
 TRAINING_JOB_WRAPPER_CODE_DIR: pathlib.Path = (
     TEMPLATES_DIR / TRAINING_JOB_WRAPPER_CODE_DIR_NAME
 )
diff --git a/truss/contexts/image_builder/serving_image_builder.py b/truss/contexts/image_builder/serving_image_builder.py
index 1a24896ab..690f8926f 100644
--- a/truss/contexts/image_builder/serving_image_builder.py
+++ b/truss/contexts/image_builder/serving_image_builder.py
@@ -17,6 +17,7 @@
     SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME,
     SYSTEM_PACKAGES_TXT_FILENAME,
     TEMPLATES_DIR,
+    TRITON_SERVER_CODE_DIR,
 )
 from truss.contexts.image_builder.image_builder import ImageBuilder
 from truss.contexts.image_builder.util import (
@@ -52,6 +53,46 @@
 HF_ACCESS_TOKEN_SECRET_NAME = "hf_access_token"
 
 
+def create_triton_build_dir(config: TrussConfig, build_dir: Path, truss_dir: Path):
+    _spec = TrussSpec(truss_dir)
+    if not build_dir.exists():
+        build_dir.mkdir(parents=True)
+
+    # The Triton server expects a specific directory structure. We create this
+    # directory structure in the build directory. The structure is:
+    #
+    # build_dir
+    # ├── model                # Name "model" is used during invocation
+    # │   └── 1                # Version of the model, used during invocation
+    # │       ├── truss        # User-defined code
+    # │       │   ├── config.yaml
+    # │       │   ├── model.py
+    # │       │   └── ...      # Other Truss files
+    # │       ├── model.py     # Triton server code
+    # │       └── ...          # Other Triton server files
+
+    model_repository_path = build_dir / "model"
+    user_truss_path = model_repository_path / "truss"  # type: ignore[operator]
+    data_dir = model_repository_path / config.data_dir  # type: ignore[operator]
+
+    copy_tree_path(TRITON_SERVER_CODE_DIR / "model", model_repository_path)
+    copy_tree_path(TRITON_SERVER_CODE_DIR / "root", build_dir)
+    copy_tree_path(truss_dir, user_truss_path)
+    copy_tree_path(
+        SHARED_SERVING_AND_TRAINING_CODE_DIR,
+        model_repository_path / SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME,
+    )
+
+    # Override config.yaml with the fully resolved config
+    with (model_repository_path / "truss" / CONFIG_FILE).open("w") as config_file:
+        yaml.dump(config.to_dict(verbose=True), config_file)
+
+    # Download external data
+    download_external_data(_spec.external_data, data_dir)
+
+    (build_dir / REQUIREMENTS_TXT_FILENAME).write_text(_spec.requirements_txt)
+    (build_dir / SYSTEM_PACKAGES_TXT_FILENAME).write_text(_spec.system_packages_txt)
+
+
 def split_gs_path(gs_path):
     # Remove the 'gs://' prefix
     path = gs_path.replace("gs://", "")
@@ -267,6 +308,9 @@ def prepare_image_build_dir(
     elif config.build.model_server is ModelServer.VLLM:
         create_vllm_build_dir(config, build_dir, truss_dir)
         return
+    elif config.build.model_server is ModelServer.TRITON:
+        create_triton_build_dir(config, build_dir, truss_dir)
+        return
 
     data_dir = build_dir / config.data_dir  # type: ignore[operator]
 
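For orientation before the server code below: this layout assumes a user `model.py` (inside the copied `truss/` directory) that declares Pydantic input/output classes and a batch-oriented `predict`. A minimal sketch of such a model; the class names follow the `Input`/`Output` defaults used later in this diff, while the fields and logic are illustrative assumptions:

```python
# Hypothetical user model.py. "Input"/"Output" are the default class names the
# server resolves (overridable via input_type_name/output_type_name in config.yaml).
from typing import List

from pydantic import BaseModel


class Input(BaseModel):
    prompt: str
    temperature: float


class Output(BaseModel):
    completion: str


class Model:
    def __init__(self, **kwargs):
        self._model = None

    def load(self):
        # Called once at startup; load weights here.
        self._model = lambda text: text.upper()

    def predict(self, inputs: List[Input]) -> List[Output]:
        # Batch in, batch out: the transform layer requires a list of
        # Pydantic objects even when the batch size is 1.
        return [Output(completion=self._model(i.prompt)) for i in inputs]
```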
diff --git a/truss/templates/triton/model/model.py b/truss/templates/triton/model/model.py
new file mode 100644
index 000000000..298cda1f8
--- /dev/null
+++ b/truss/templates/triton/model/model.py
@@ -0,0 +1,58 @@
+from pathlib import Path
+
+import yaml
+from triton_model_wrapper import TritonModelWrapper
+
+
+class TritonPythonModel:
+    def __init__(self):
+        self._model_repository_path = None
+        self._config_path = None
+        self._config = None
+        self._model_wrapper: TritonModelWrapper = None
+
+    def _instantiate_model_wrapper(self, triton_config: dict) -> None:
+        """
+        Instantiates the model wrapper class as well as the user-defined model class.
+
+        Args:
+            triton_config (dict): Triton configuration dictionary. This contains
+                information about the model repository path and the model version.
+
+        Returns:
+            None
+        """
+        self._model_repository_path = (
+            Path(triton_config["model_repository"]) / triton_config["model_version"]
+        )
+        self._config_path = self._model_repository_path / "truss" / "config.yaml"
+        with open(self._config_path, encoding="utf-8") as config_file:
+            self._config = yaml.safe_load(config_file)
+        self._model_wrapper: TritonModelWrapper = TritonModelWrapper(self._config)
+        self._model_wrapper.instantiate()
+
+    def initialize(self, triton_config: dict) -> None:
+        """
+        Instantiates the model wrapper class and loads the user's model. This
+        function is called by Triton upon startup.
+
+        Args:
+            triton_config (dict): Triton configuration dictionary. This contains
+                information about the model repository path and the model version
+                as well as all the information defined in the config.pbtxt.
+
+        Returns:
+            None
+        """
+        self._instantiate_model_wrapper(triton_config)
+        self._model_wrapper.load()
+
+    def execute(self, requests: list) -> list:
+        """
+        Executes the user's model on Triton InferenceRequest objects. This
+        function is called by Triton upon inference.
+
+        Args:
+            requests (List[InferenceRequest]): Triton InferenceRequest objects
+                containing the input data for the model.
+
+        Returns:
+            List[InferenceResponse]: Triton InferenceResponse objects containing
+                the output data from the model.
+        """
+        return self._model_wrapper.predict(requests)
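To make the entrypoint's contract concrete, here is a sketch of the call sequence Triton drives against this class. The dictionary keys are the ones this template reads; the literal paths are assumptions based on the container layout set up later in this diff:

```python
# Call order only; outside a running Triton server the config.yaml read in
# initialize() would fail.
model = TritonPythonModel()
model.initialize(
    {"model_repository": "/app/model", "model_version": "1"}  # assumed paths
)
# For each batch, Triton then calls:
# responses = model.execute(requests)  # requests: list of pb_utils.InferenceRequest
```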
diff --git a/truss/templates/triton/model/transform.py b/truss/templates/triton/model/transform.py
new file mode 100644
index 000000000..a9d6ee521
--- /dev/null
+++ b/truss/templates/triton/model/transform.py
@@ -0,0 +1,102 @@
+import json
+from typing import List, Type
+
+import numpy as np
+from pydantic import BaseModel
+from utils.errors import InvalidModelResponseError
+from utils.pydantic import PYTHON_TYPE_TO_NP_DTYPE, inspect_pydantic_model
+from utils.triton import (
+    InferenceRequest,
+    InferenceResponse,
+    convert_tensor_to_python_type,
+    create_inference_response,
+    create_tensor,
+    get_input_tensor_by_name,
+)
+
+
+def transform_triton_to_pydantic(
+    triton_requests: List[InferenceRequest],
+    pydantic_type: Type[BaseModel],
+) -> List[BaseModel]:
+    """
+    Transforms a list of Triton requests into a list of Pydantic objects.
+
+    This function inspects the fields of the given Pydantic type and extracts
+    the corresponding tensors from the Triton requests. It then converts these
+    tensors into Python types and uses them to instantiate the Pydantic objects.
+
+    Args:
+        triton_requests (list): A list of Triton request objects.
+        pydantic_type (Type[BaseModel]): The Pydantic model type to which the
+            Triton requests should be transformed.
+
+    Returns:
+        list: A list of Pydantic objects instantiated from the Triton requests.
+
+    Raises:
+        ValueError: If a tensor corresponding to a Pydantic field cannot be
+            found in the Triton request.
+        UnsupportedTypeError: If a tensor corresponding to a Pydantic field has
+            an unsupported type.
+    """
+    fields = inspect_pydantic_model(pydantic_type)
+    results = []
+
+    for request in triton_requests:
+        data = {}
+        for field in fields:
+            field_name, field_type = field.values()
+            tensor = get_input_tensor_by_name(request, field_name)
+            data[field_name] = convert_tensor_to_python_type(tensor, field_type)
+        results.append(pydantic_type(**data))
+
+    return results
+
+
+def transform_pydantic_to_triton(
+    pydantic_objects: List[BaseModel],
+) -> List[InferenceResponse]:
+    """
+    Transforms a list of Pydantic objects into a list of Triton inference responses.
+
+    This function iterates over the fields of each Pydantic object, determines the
+    appropriate tensor data type, and creates a tensor for each field. These tensors
+    are then used to create Triton inference responses.
+
+    Args:
+        pydantic_objects (list): A list of Pydantic objects to be transformed.
+
+    Returns:
+        list: A list of Triton inference responses created from the Pydantic objects.
+
+    Raises:
+        InvalidModelResponseError: If the model did not return a list of
+            Pydantic objects.
+        KeyError: If a Python type in the Pydantic object does not have a
+            corresponding numpy dtype.
+    """
+    results = []
+
+    if not isinstance(pydantic_objects, list):
+        raise InvalidModelResponseError(
+            "Truss did not return a list. Please ensure that your model returns a "
+            "list of Pydantic objects even if the batch size is 1."
+        )
+
+    if not all(isinstance(obj, BaseModel) for obj in pydantic_objects):
+        raise InvalidModelResponseError(
+            "Truss returned a list of objects that are not Pydantic models. Please "
+            "ensure that your model returns a list of Pydantic objects even if the "
+            "batch size is 1."
+        )
+
+    for obj in pydantic_objects:
+        tensors = []
+        for field, value in obj.dict().items():
+            if isinstance(value, list):
+                dtype = PYTHON_TYPE_TO_NP_DTYPE[type(value[0])]
+            elif isinstance(value, dict):
+                dtype = PYTHON_TYPE_TO_NP_DTYPE[type(value)]
+                value = [json.dumps(value)]
+            else:
+                dtype = PYTHON_TYPE_TO_NP_DTYPE[type(value)]
+                value = [value]
+            tensors.append(create_tensor(field, np.array(value, dtype=dtype)))
+        results.append(create_inference_response(tensors))
+
+    return results
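The wire convention these transforms assume (scalar fields as one-element named tensors, `dict` fields as JSON-encoded strings) can be pictured without `pb_utils`. A self-contained sketch with a hypothetical `Input` type and plain numpy arrays standing in for Triton tensors:

```python
import json

import numpy as np
from pydantic import BaseModel


class Input(BaseModel):  # hypothetical user input type
    prompt: str
    options: dict


# Each Pydantic field arrives as a named tensor: scalars as one-element
# arrays, dicts as JSON strings (mirroring convert_tensor_to_python_type).
raw_tensors = {
    "prompt": np.array(["hello"], dtype=object),
    "options": np.array([json.dumps({"max_tokens": 16})], dtype=object),
}
data = {
    "prompt": raw_tensors["prompt"].item(),
    "options": json.loads(raw_tensors["options"].item()),
}
assert Input(**data).options["max_tokens"] == 16
```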
diff --git a/truss/templates/triton/model/triton_model_wrapper.py b/truss/templates/triton/model/triton_model_wrapper.py
new file mode 100644
index 000000000..0675c9358
--- /dev/null
+++ b/truss/templates/triton/model/triton_model_wrapper.py
@@ -0,0 +1,81 @@
+import importlib
+import inspect
+import sys
+from pathlib import Path
+from typing import Dict
+
+import triton_python_backend_utils as pb_utils
+from shared.secrets_resolver import SecretsResolver
+from transform import transform_pydantic_to_triton, transform_triton_to_pydantic
+from utils.errors import MissingInputClassError, MissingOutputClassError
+from utils.signature import signature_accepts_keyword_arg
+
+
+class TritonModelWrapper:
+    def __init__(self, config: Dict):
+        self._config = config
+        self._model = None
+        self._input_type = None
+        self._output_type = None
+
+    def instantiate(self) -> None:
+        """
+        Instantiates the model class defined in the config.
+        """
+        model_directory_prefix = Path("/app/model/1/truss")
+        data_dir = model_directory_prefix / "data"
+
+        if "bundled_packages_dir" in self._config:
+            bundled_packages_path = model_directory_prefix / "bundled_packages"
+            if bundled_packages_path.exists():
+                sys.path.append(str(bundled_packages_path))
+        model_module_name = str(
+            Path(self._config["model_class_filename"]).with_suffix("")
+        )
+        module = importlib.import_module(
+            f"truss.{self._config['model_module_dir']}.{model_module_name}"
+        )
+        model_class = getattr(module, self._config["model_class_name"])
+
+        input_type_name = self._config.get("input_type_name", "Input")
+        output_type_name = self._config.get("output_type_name", "Output")
+        self._input_type = getattr(module, input_type_name, None)
+        self._output_type = getattr(module, output_type_name, None)
+
+        if not self._input_type:
+            raise MissingInputClassError("Input type not defined in user model")
+        if not self._output_type:
+            raise MissingOutputClassError("Output type not defined in user model")
+
+        model_class_signature = inspect.signature(model_class)
+        model_init_params = {}
+        if signature_accepts_keyword_arg(model_class_signature, "config"):
+            model_init_params["config"] = self._config
+        if signature_accepts_keyword_arg(model_class_signature, "data_dir"):
+            model_init_params["data_dir"] = data_dir  # type: ignore
+        if signature_accepts_keyword_arg(model_class_signature, "secrets"):
+            model_init_params["secrets"] = SecretsResolver.get_secrets(self._config)
+
+        self._model = model_class(**model_init_params)
+
+    def load(self) -> None:
+        """
+        Invokes the underlying model's load method.
+        """
+        if self._model is None:
+            raise ValueError("Model not instantiated")
+
+        self._model.load()
+
+    def predict(self, requests: list) -> list:
+        """
+        Converts Triton requests to Pydantic objects, invokes the underlying
+        model's predict method, and converts the outputs to Triton objects.
+        """
+        requests = transform_triton_to_pydantic(requests, self._input_type)
+        try:
+            outputs = self._model.predict(requests)  # type: ignore
+        except Exception as e:
+            raise pb_utils.TritonModelException(e)
+
+        return transform_pydantic_to_triton(outputs)
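A sketch of the `config.yaml` keys this wrapper reads, shown as the dict it receives. The values are illustrative examples, and the calls only succeed inside the container where the user module is importable:

```python
# Keys taken from the lookups above; values are examples, not defaults.
config = {
    "model_module_dir": "model",
    "model_class_filename": "model.py",
    "model_class_name": "Model",
    "input_type_name": "Input",    # optional; defaults to "Input"
    "output_type_name": "Output",  # optional; defaults to "Output"
    "bundled_packages_dir": "packages",  # optional; appended to sys.path
}

wrapper = TritonModelWrapper(config)
wrapper.instantiate()  # imports the user module and resolves Input/Output
wrapper.load()
```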
diff --git a/truss/templates/triton/model/utils/errors.py b/truss/templates/triton/model/utils/errors.py
new file mode 100644
index 000000000..a6327579f
--- /dev/null
+++ b/truss/templates/triton/model/utils/errors.py
@@ -0,0 +1,37 @@
+class Error(Exception):
+    """Base Truss Error"""
+
+    def __init__(self, message: str):
+        super().__init__(message)
+        self.message = message
+        self.type = type(self)  # store the concrete error class
+
+
+class UnsupportedTypeError(Error):
+    """Raised when a Pydantic field type is not supported"""
+
+    pass
+
+
+class MissingFieldError(Error):
+    """Raised when a Pydantic field is missing"""
+
+    pass
+
+
+class MissingInputClassError(Error):
+    """Raised when the user does not define an input class"""
+
+    pass
+
+
+class MissingOutputClassError(Error):
+    """Raised when the user does not define an output class"""
+
+    pass
+
+
+class InvalidModelResponseError(Error):
+    """Raised when the model returns an invalid response"""
+
+    pass
diff --git a/truss/templates/triton/model/utils/pydantic.py b/truss/templates/triton/model/utils/pydantic.py
new file mode 100644
index 000000000..3d09c889a
--- /dev/null
+++ b/truss/templates/triton/model/utils/pydantic.py
@@ -0,0 +1,73 @@
+from typing import Any, List, Type, Union
+
+import numpy as np
+from pydantic import BaseModel
+from utils.errors import UnsupportedTypeError
+
+# This is a broad type that includes built-in Python types and types that might be
+# returned by Pydantic functions like conlist.
+FieldAnnotationType = Union[
+    Type[int], Type[float], Type[str], Type[bool], Type[dict], List[Any]
+]
+
+# Dictionary to specify the numpy dtype when converting from Python types to Triton tensors
+PYTHON_TYPE_TO_NP_DTYPE = {
+    int: np.int32,
+    float: np.float32,
+    str: np.dtype(object),
+    bool: np.bool_,
+    dict: np.dtype(object),
+}
+
+
+def is_conlist(pydantic_field_annotation: FieldAnnotationType) -> bool:
+    """Checks if a Pydantic field annotation is a conlist."""
+    # Annotations that are not directly a Python type have an __origin__ attribute
+    if not hasattr(pydantic_field_annotation, "__origin__"):
+        return False
+    return pydantic_field_annotation.__origin__ == list
+
+
+def get_pydantic_field_type(pydantic_field_annotation: FieldAnnotationType) -> str:
+    """
+    Returns the type of a Pydantic field via its annotation. The annotation can
+    be a plain type or a type with constraints (e.g. conlist). For the following
+    types on a Pydantic field, the annotation is:
+
+    - str: str
+    - int: int
+    - float: float
+    - bool: bool
+    - list: list
+    - conlist: typing.List[...] where ... is the type of the list elements
+
+    Args:
+        pydantic_field_annotation (FieldAnnotationType): The Pydantic field annotation.
+
+    Returns:
+        The type (str) of the Pydantic field.
+    """
+    if isinstance(pydantic_field_annotation, type):
+        return pydantic_field_annotation.__name__
+    elif is_conlist(pydantic_field_annotation):
+        return "list"
+    else:
+        raise UnsupportedTypeError(
+            f"Unsupported Pydantic field annotation type: {pydantic_field_annotation}"
+        )
+
+
+def inspect_pydantic_model(pydantic_class: Type[BaseModel]) -> list:
+    """
+    Inspects a Pydantic model and returns a list of fields and their types.
+
+    Args:
+        pydantic_class (Type[BaseModel]): The Pydantic model class.
+
+    Returns:
+        A list of dicts containing the field name and type.
+    """
+    return [
+        {"param": field_name, "type": get_pydantic_field_type(field_info.annotation)}
+        for field_name, field_info in pydantic_class.__fields__.items()  # type: ignore
+    ]
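A small usage sketch of `inspect_pydantic_model`, assuming Pydantic v2 (which the annotation/metadata handling in this diff implies) and that the function is importable from the module above:

```python
from pydantic import BaseModel, conlist


class Input(BaseModel):  # hypothetical user input type
    prompt: str
    embedding: conlist(float, min_length=4, max_length=4)


assert inspect_pydantic_model(Input) == [
    {"param": "prompt", "type": "str"},
    {"param": "embedding", "type": "list"},
]
```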
diff --git a/truss/templates/triton/model/utils/signature.py b/truss/templates/triton/model/utils/signature.py
new file mode 100644
index 000000000..277b1d5f2
--- /dev/null
+++ b/truss/templates/triton/model/utils/signature.py
@@ -0,0 +1,14 @@
+import inspect
+
+
+def _signature_accepts_kwargs(signature: inspect.Signature) -> bool:
+    """Checks if a signature accepts **kwargs."""
+    for param in signature.parameters.values():
+        if param.kind == inspect.Parameter.VAR_KEYWORD:
+            return True
+    return False
+
+
+def signature_accepts_keyword_arg(signature: inspect.Signature, kwarg: str) -> bool:
+    """Checks if a signature accepts a keyword argument with the given name."""
+    return kwarg in signature.parameters or _signature_accepts_kwargs(signature)
diff --git a/truss/templates/triton/model/utils/triton.py b/truss/templates/triton/model/utils/triton.py
new file mode 100644
index 000000000..6b3be8180
--- /dev/null
+++ b/truss/templates/triton/model/utils/triton.py
@@ -0,0 +1,77 @@
+import json
+from typing import Any, List
+
+import numpy as np
+from utils.errors import UnsupportedTypeError
+
+
+# Placeholder classes used only for type annotations. The real implementations
+# come from triton_python_backend_utils, which is only importable inside the
+# Triton server, so the helpers below import it lazily.
+class Tensor:
+    def as_numpy(self) -> np.ndarray:
+        return  # type: ignore
+
+
+class InferenceRequest:
+    pass
+
+
+class InferenceResponse:
+    pass
+
+
+def get_numpy_item(tensor: Tensor) -> Any:
+    """Returns the single item of a tensor's numpy array as a Python type."""
+    item = tensor.as_numpy().item()
+    return item.decode("utf-8") if isinstance(item, bytes) else item
+
+
+def convert_tensor_to_python_type(tensor: Tensor, dtype: str) -> Any:
+    """Converts a Triton tensor to a Python type."""
+
+    def _convert_tensor_to_dict(tensor: Tensor) -> dict:
+        item = get_numpy_item(tensor)
+        try:
+            coerced_dict = json.loads(item)
+        except json.decoder.JSONDecodeError:
+            raise ValueError(
+                f"Could not decode string to dict: {item}. "
+                "Please ensure that the string is valid JSON."
+            )
+        return coerced_dict
+
+    if dtype == "dict":
+        return _convert_tensor_to_dict(tensor)
+    elif dtype == "list":
+        return tensor.as_numpy().flatten().tolist()
+    elif dtype in ["str", "int", "float", "bool"]:
+        return get_numpy_item(tensor)
+    else:
+        raise UnsupportedTypeError(f"Unsupported dtype: {dtype}")
+
+
+def get_input_tensor_by_name(obj: InferenceRequest, name: str) -> Tensor:
+    """Extracts a tensor from a Triton InferenceRequest by name."""
+    from triton_python_backend_utils import get_input_tensor_by_name
+
+    x = get_input_tensor_by_name(obj, name)
+    if x is None:
+        raise ValueError(f"Input tensor {name} not found in request.")
+    return x
+
+
+def create_tensor(name: str, value: Any) -> Tensor:
+    """Creates a Triton tensor."""
+    from triton_python_backend_utils import Tensor
+
+    return Tensor(name, value)
+
+
+def create_inference_response(objects: List[Tensor]) -> InferenceResponse:
+    """Creates a Triton InferenceResponse from a list of Triton tensors."""
+    from triton_python_backend_utils import InferenceResponse
+
+    return InferenceResponse(objects)
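The keyword-argument probe in `utils/signature.py` is what lets a user `__init__` opt into `config`, `data_dir`, and `secrets`. A quick sketch of its behavior with a hypothetical user class, assuming the helper is importable:

```python
import inspect


class UserModel:  # hypothetical user model class
    def __init__(self, config, **kwargs):
        self._config = config


sig = inspect.signature(UserModel)
assert signature_accepts_keyword_arg(sig, "config")   # declared explicitly
assert signature_accepts_keyword_arg(sig, "secrets")  # accepted via **kwargs
```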
diff --git a/truss/templates/triton/root/Dockerfile b/truss/templates/triton/root/Dockerfile
new file mode 100644
index 000000000..d48e6f676
--- /dev/null
+++ b/truss/templates/triton/root/Dockerfile
@@ -0,0 +1,49 @@
+FROM nvcr.io/nvidia/tritonserver:23.07-py3
+
+ENV PYTHONUNBUFFERED True
+ENV DEBIAN_FRONTEND=noninteractive
+ENV APP_HOME /app
+
+WORKDIR $APP_HOME
+
+RUN pip install --upgrade pip --no-cache-dir \
+    && pip install pydantic pyyaml jinja2 --no-cache-dir \
+    && rm -rf /root/.cache/pip
+
+# Install fixed system dependencies
+RUN apt update && \
+    apt install -y bash \
+    build-essential \
+    git \
+    curl \
+    ca-certificates \
+    software-properties-common \
+    nginx \
+    supervisor \
+    && apt-get autoremove -y \
+    && apt-get clean -y \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy system_packages.txt and install the packages listed in it
+COPY ./system_packages.txt /app/system_packages.txt
+RUN apt-get update && apt-get install --yes --no-install-recommends $(cat system_packages.txt) \
+    && apt-get autoremove -y \
+    && apt-get clean -y \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy model and other static files
+COPY ./model /app/model/1
+COPY ./generate_config.py /app/generate_config.py
+COPY ./proxy.conf /etc/nginx/conf.d/proxy.conf
+COPY ./supervisord.conf /etc/supervisor/conf.d/supervisord.conf
+COPY ./config.pbtxt.jinja /app/config.pbtxt.jinja
+
+# Copy requirements.txt and install Python packages
+COPY ./requirements.txt /app/requirements.txt
+RUN pip install -r requirements.txt --no-cache-dir && rm -rf /root/.cache/pip
+
+# Generate config.pbtxt
+RUN python3 /app/generate_config.py
+
+ENV SERVER_START_CMD /usr/bin/supervisord
+ENTRYPOINT ["/usr/bin/supervisord"]
diff --git a/truss/templates/triton/root/config.pbtxt.jinja b/truss/templates/triton/root/config.pbtxt.jinja
new file mode 100644
index 000000000..5075a8488
--- /dev/null
+++ b/truss/templates/triton/root/config.pbtxt.jinja
@@ -0,0 +1,39 @@
+{#
+A sample config.pbtxt template for a model that uses the Python backend.
+https://github.com/triton-inference-server/backend/blob/504abc9dc74f2f9aa907e561620c411ffbdb3f3c/examples/model_repos/minimal_models/batching/config.pbtxt#L4
+#}
+
+name: "model"
+backend: "python"
+max_batch_size: {{ max_batch_size }}
+
+{% if dynamic_batching_delay_microseconds is defined %}
+dynamic_batching { max_queue_delay_microseconds: {{ dynamic_batching_delay_microseconds }} }
+{% endif %}
+
+input [
+{% for inp in inputs %}
+  {
+    name: "{{ inp.name }}"
+    data_type: {{ inp.triton_type }}
+    dims: [{{ inp.dims|join(', ') }}]
+  }{% if not loop.last %},{% endif %}
+{% endfor %}
+]
+
+output [
+{% for out in outputs %}
+  {
+    name: "{{ out.name }}"
+    data_type: {{ out.triton_type }}
+    dims: [{{ out.dims|join(', ') }}]
+  }{% if not loop.last %},{% endif %}
+{% endfor %}
+]
+
+instance_group [
+  {
+    kind: {% if is_gpu %}KIND_GPU{% else %}KIND_CPU{% endif %}
+    count: {{ num_replicas }}
+  }
+]
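To see what this template produces, here is a sketch that renders it directly with hand-written field descriptors. `generate_config.py` (next in this diff) derives the same values from the user's Pydantic classes; the loader path is an assumption matching the Dockerfile above:

```python
from jinja2 import Environment, FileSystemLoader

template = Environment(loader=FileSystemLoader("/app")).get_template(
    "config.pbtxt.jinja"
)
print(
    template.render(
        max_batch_size=4,
        num_replicas=1,
        is_gpu=False,
        inputs=[{"name": "prompt", "triton_type": "TYPE_STRING", "dims": [-1]}],
        outputs=[{"name": "completion", "triton_type": "TYPE_STRING", "dims": [-1]}],
    )
)
# Prints a config.pbtxt with name "model", backend "python", one string
# input/output, and a CPU instance group.
```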
diff --git a/truss/templates/triton/root/generate_config.py b/truss/templates/triton/root/generate_config.py
new file mode 100644
index 000000000..831a4efd7
--- /dev/null
+++ b/truss/templates/triton/root/generate_config.py
@@ -0,0 +1,215 @@
+import importlib
+from pathlib import Path
+from types import ModuleType
+from typing import Any, List, Type, Union
+
+import yaml
+from jinja2 import Environment, FileSystemLoader, Template
+from pydantic import BaseModel
+
+# Copied from truss/templates/triton/model/utils/pydantic.py
+FieldAnnotationType = Union[
+    Type[int], Type[float], Type[str], Type[bool], Type[dict], List[Any]
+]
+
+# Mapping from Pydantic types to Triton types
+TYPE_MAPPING = {
+    "str": "TYPE_STRING",
+    "int": "TYPE_INT32",
+    "float": "TYPE_FP32",
+    "bool": "TYPE_BOOL",
+    "dict": "TYPE_STRING",
+}
+
+
+class FieldDescriptor:
+    """Descriptor for fields in a Pydantic model."""
+
+    def __init__(self, name: str, triton_type: str, dims: List[int]):
+        self.name = name
+        self.triton_type = triton_type
+        self.dims = dims
+
+    @classmethod
+    def from_pydantic_field(cls, field_name: str, field_info: Any) -> "FieldDescriptor":
+        """Create a FieldDescriptor from a Pydantic field."""
+        return cls(
+            name=field_name,
+            triton_type=get_triton_type(field_info.annotation),
+            dims=get_triton_dims(field_info),
+        )
+
+    @classmethod
+    def from_pydantic_class(
+        cls, pydantic_class: Type[BaseModel]
+    ) -> List["FieldDescriptor"]:
+        """Create a list of FieldDescriptors from a Pydantic class."""
+        return [
+            cls.from_pydantic_field(field_name, field_info)
+            for field_name, field_info in pydantic_class.__fields__.items()  # type: ignore
+        ]
+
+    def to_dict(self) -> dict:
+        """Convert the FieldDescriptor to a dictionary."""
+        return {
+            "name": self.name,
+            "dims": self.dims,
+            "triton_type": self.triton_type,
+        }
+
+
+def read_template_from_fs(base_dir: Path, template_file_name: str) -> Template:
+    """Read a Jinja template from the filesystem."""
+    template_loader = FileSystemLoader(str(base_dir))
+    template_env = Environment(loader=template_loader)
+    return template_env.get_template(template_file_name)
+
+
+def path_to_module(path: Path) -> str:
+    """Convert a filesystem path to a Python module string."""
+    return str(path).replace("/", ".")
+
+
+def is_conlist(pydantic_field_annotation: FieldAnnotationType) -> bool:
+    """Check if a Pydantic field annotation is a conlist."""
+    return (
+        hasattr(pydantic_field_annotation, "__origin__")
+        and pydantic_field_annotation.__origin__ == list
+    )
+
+
+def get_triton_type(pydantic_field: FieldAnnotationType) -> str:
+    """Get the corresponding Triton type for a Pydantic field."""
+    if is_conlist(pydantic_field):
+        pydantic_field = pydantic_field.__args__[0]  # type: ignore
+
+    type_name = (
+        pydantic_field.__name__
+        if hasattr(pydantic_field, "__name__")
+        else str(pydantic_field)
+    )
+
+    if type_name not in TYPE_MAPPING:
+        raise TypeError(f"Unsupported type: {type_name}")
+
+    return TYPE_MAPPING[type_name]
+
+
+def get_triton_dims(field_info: Any) -> List[int]:
+    """Get the Triton dimensions for a Pydantic field's FieldInfo."""
+    if is_conlist(field_info.annotation):
+        dims = next(
+            (m for m in field_info.metadata if hasattr(m, "min_length")), None
+        )
+        if dims and dims.min_length == dims.max_length:
+            return [-1, dims.min_length]
+        else:
+            return [-1, -1]
+    return [-1]
+
+
+def validate_user_model_class(module: ModuleType, required_attributes: List[str]):
+    """
+    Validates that a user model class has the required attributes.
+
+    Args:
+        module (ModuleType): The module containing the user model class.
+        required_attributes (list): A list of required attributes.
+
+    Raises:
+        AttributeError: If a required attribute is missing.
+    """
+    for attribute in required_attributes:
+        if not hasattr(module, attribute):
+            raise AttributeError(
+                f"Truss model class is missing {attribute} attribute. For Triton, "
+                f"we require a Pydantic model class that corresponds to your model "
+                f"input and model output."
+            )
+
+
+def generate_config_pbtxt(
+    input_class: Type[BaseModel],
+    output_class: Type[BaseModel],
+    template_path: Path,
+    template_name: str = "config.pbtxt.jinja",
+    max_batch_size: int = 1,
+    num_replicas: int = 1,
+    dynamic_batch_delay_ms: int = 0,
+    is_gpu: bool = False,
+) -> str:
+    """Generate the config.pbtxt file content."""
+    config_params = {
+        "max_batch_size": max_batch_size,
+        "num_replicas": num_replicas,
+        "is_gpu": is_gpu,
+    }
+    template = read_template_from_fs(template_path, template_name)
+    input_cls_field_descriptors = FieldDescriptor.from_pydantic_class(input_class)
+    output_cls_field_descriptors = FieldDescriptor.from_pydantic_class(output_class)
+    inputs = [field.to_dict() for field in input_cls_field_descriptors]
+    outputs = [field.to_dict() for field in output_cls_field_descriptors]
+    if dynamic_batch_delay_ms > 0:
+        config_params["dynamic_batching_delay_microseconds"] = dynamic_batch_delay_ms
+    return template.render(inputs=inputs, outputs=outputs, **config_params)
+
+
+def load_config(user_truss_path: Path) -> dict:
+    """Load the configuration from the YAML file."""
+    with open(user_truss_path / "config.yaml", "r") as f:
+        return yaml.safe_load(f)
+
+
+def extract_config_arguments(config: dict) -> dict:
+    """Extract configuration arguments from the loaded config."""
+    build = config.get("build", {})
+    build.pop("model_server", None)
+    return {k: v for d in build["arguments"] for k, v in d.items()}
+
+
+def determine_gpu_usage(config: dict) -> bool:
+    """Determine if GPU is enabled based on the config."""
+    return config.get("resources", {}).get("use_gpu", False)
+
+
+def import_user_model(user_truss_path: Path, model_class_filename: str) -> ModuleType:
+    """Import the user model class."""
+    model_module_name = str(Path(model_class_filename).with_suffix(""))
+    return importlib.import_module(
+        f"{path_to_module(user_truss_path)}.model.{model_module_name}"
+    )
+
+
+def write_config_to_file(model_repository_path: Path, config_pbtxt: str):
+    """Write the generated config.pbtxt to the model directory."""
+    with open(model_repository_path / "config.pbtxt", "w") as f:
+        f.write(config_pbtxt)
+
+
+def main():
+    """Main function to generate the config.pbtxt."""
+    model_repository_path = Path("model")
+    user_truss_path = model_repository_path / "1" / "truss"
+
+    # Load and process configuration
+    config = load_config(user_truss_path)
+    config_arguments = extract_config_arguments(config)
+    config_arguments["is_gpu"] = determine_gpu_usage(config)
+
+    # Import and validate user model
+    module = import_user_model(user_truss_path, config["model_class_filename"])
+    input_cls_name = config.get("input_type_name", "Input")
+    output_cls_name = config.get("output_type_name", "Output")
+    validate_user_model_class(module, [input_cls_name, output_cls_name])
+
+    # Generate and write config.pbtxt
+    input_cls = getattr(module, input_cls_name)
+    output_cls = getattr(module, output_cls_name)
+    triton_config = generate_config_pbtxt(
+        input_cls, output_cls, Path("/app"), **config_arguments
+    )
+    write_config_to_file(model_repository_path, triton_config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/truss/templates/triton/root/proxy.conf b/truss/templates/triton/root/proxy.conf
new file mode 100644
index 000000000..660fe969e
--- /dev/null
+++ b/truss/templates/triton/root/proxy.conf
@@ -0,0 +1,29 @@
+
+server {
+    listen 8080;
+
+    # Liveness
+    location / {
+        add_header Content-Type text/plain;
+        return 200 "";
+    }
+
+    # Readiness
+    location ~ ^/v1/models/model$ {
+        proxy_redirect off;
+
+        rewrite ^/v1/models/model$ /v2/health/ready break;
+
+        proxy_pass http://127.0.0.1:8000;
+    }
+
+    # Predict
+    location ~ ^/v1/models/model:predict$ {
+        proxy_redirect off;
+
+        rewrite ^/v1/models/model:predict$ /v2/models/model/infer break;
+
+        proxy_pass http://127.0.0.1:8000;
+    }
+
+}
diff --git a/truss/templates/triton/root/supervisord.conf b/truss/templates/triton/root/supervisord.conf
new file mode 100644
index 000000000..12fd7d7aa
--- /dev/null
+++ b/truss/templates/triton/root/supervisord.conf
@@ -0,0 +1,25 @@
+[supervisord]
+nodaemon=true
+logfile=/dev/null
+logfile_maxbytes=0
+
+
+[program:triton-server]
+command=tritonserver --model-repository=/app
+startsecs=0
+startretries=1
+autostart=true
+autorestart=true
+stdout_logfile=/dev/fd/1
+stdout_logfile_maxbytes=0
+redirect_stderr=true
+
+[program:nginx]
+command=nginx -g "daemon off;"
+startsecs=0
+startretries=1
+autostart=true
+autorestart=true
+stdout_logfile=/dev/fd/1
+stdout_logfile_maxbytes=0
+redirect_stderr=true
diff --git a/truss/truss_config.py b/truss/truss_config.py
index 54a080cb1..cf4cc400a 100644
--- a/truss/truss_config.py
+++ b/truss/truss_config.py
@@ -149,6 +149,7 @@ class ModelServer(Enum):
     TrussServer = "TrussServer"
     TGI = "TGI"
     VLLM = "VLLM"
+    TRITON = "TRITON"
 
 
 @dataclass
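Finally, a hypothetical client call through the nginx proxy above. The body follows Triton's KServe-v2 infer format, the field names match the illustrative `Input` class used earlier, and the shapes are illustrative:

```python
import requests

body = {
    "inputs": [
        {"name": "prompt", "shape": [1, 1], "datatype": "BYTES", "data": ["hello"]},
        {"name": "temperature", "shape": [1, 1], "datatype": "FP32", "data": [0.7]},
    ]
}
# nginx rewrites this Truss-style path to /v2/models/model/infer on Triton.
resp = requests.post("http://localhost:8080/v1/models/model:predict", json=body)
print(resp.json()["outputs"])  # named output tensors in v2 format
```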