diff --git a/docs/customization_guide/tritonfrontend.md b/docs/customization_guide/tritonfrontend.md index b46206f2f3..092515de37 100644 --- a/docs/customization_guide/tritonfrontend.md +++ b/docs/customization_guide/tritonfrontend.md @@ -25,7 +25,7 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. --> -### Triton Server (tritonfrontend) Bindings (Beta) +## Triton Server (tritonfrontend) Bindings (Beta) The `tritonfrontend` python package is a set of bindings to Triton's existing frontends implemented in C++. Currently, `tritonfrontend` supports starting up @@ -35,13 +35,20 @@ with Triton's Python In-Process API and [`tritonclient`](https://github.com/triton-inference-server/client/tree/main/src/python/library) extend the ability to use Triton's full feature set with a few lines of Python. -Let us walk through a simple example: -1. First we need to load the desired models and start the server with `tritonserver`. +### Example Workflow: + +1. Enter the Triton container: +```bash +docker run -ti nvcr.io/nvidia/tritonserver:{YY.MM}-python-py3 +``` +Note: The tritonfrontend/tritonserver wheels have been shipped and installed by default in the container since the 24.11 release. + +2. Next, load the desired models and start the server with `tritonserver`. ```python import tritonserver # Constructing path to Model Repository -model_path = f"server/src/python/examples/example_model_repository" +model_path = "server/src/python/examples/example_model_repository" server_options = tritonserver.Options( server_id="ExampleServer", @@ -83,7 +90,7 @@ url = "localhost:8000" client = httpclient.InferenceServerClient(url=url) # Prepare input data -input_data = np.array([["Roger Roger"]], dtype=object) +input_data = np.array(["Roger Roger"], dtype=object) # Create input and output objects inputs = [httpclient.InferInput("INPUT0", input_data.shape, "BYTES")] @@ -139,12 +146,61 @@ server.stop() ``` With this workflow, you can avoid having to stop each service after client requests have terminated. +### Example with RestrictedFeatures: +In order to restrict access to certain endpoints (inference, metadata, model-repo, ...), `RestrictedFeatures` can be utilized. +Let us walk through an example of restricting inference: +1. Similar to the previous workflow, we start by getting the server up and running. +```python +import tritonserver + +model_path = "server/src/python/examples/example_model_repository" + +server = tritonserver.Server(model_repository=model_path).start(wait_until_ready=True) +``` + +2. Now, we can restrict inference and start the endpoints. +```python +from tritonfrontend import Feature, RestrictedFeatures, KServeHttp + +rf = RestrictedFeatures() +rf.create_feature_group("some-infer-key", "secret-infer-value", [Feature.INFERENCE]) + +http_options = KServeHttp.Options(restricted_features=rf) +http_service = KServeHttp(server, http_options) +http_service.start() +```
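+
+The same `RestrictedFeatures` object can also be handed to the gRPC frontend. Below is a minimal sketch (assuming the `server` and `rf` objects from the steps above); note that gRPC clients must send the restricted-feature header key with a `triton-grpc-protocol-` prefix:
+```python
+from tritonfrontend import KServeGrpc
+
+# Restrict inference on the gRPC endpoint with the same credentials.
+grpc_options = KServeGrpc.Options(restricted_features=rf)
+grpc_service = KServeGrpc(server, grpc_options)
+grpc_service.start()
+```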
+ +3. Finally, let us try sending an inference request to these endpoints: ```python import tritonclient.http as httpclient import numpy as np -## Known Issues +model_name = "identity" +url = "localhost:8000" +valid_credentials = {"some-infer-key": "secret-infer-value"} +with httpclient.InferenceServerClient(url=url) as client: + input_data = np.array(["Roger Roger"], dtype=object) + inputs = [httpclient.InferInput("INPUT0", input_data.shape, "BYTES")] + inputs[0].set_data_from_numpy(input_data) + results = client.infer(model_name, inputs=inputs, headers=valid_credentials) + output_data = results.as_numpy("OUTPUT0") + print("[INFERENCE RESULTS]") + print("Output data:", output_data) +``` +Note: If you remove the `headers=valid_credentials` argument from `client.infer()`, +then the inference request fails with an error that looks something like this: +``` +... +tritonclient.utils.InferenceServerException: [403] This API is restricted, +expecting header 'some-infer-key' +``` +For more information on `RestrictedFeatures`, take a look at the following supporting docs: +- [limit endpoint access docs](https://github.com/triton-inference-server/server/blob/main/docs/customization_guide/inference_protocols.md#limit-endpoint-access-beta) +- [restricted features implementation](https://github.com/triton-inference-server/server/blob/main/src/python/tritonfrontend/_api/_restricted_features.py) +### Known Issues - The following features are not currently supported when launching the Triton frontend services through the python bindings: - [Tracing](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/trace.md) - [Shared Memory](https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_shared_memory.md) - - [Restricted Protocols](https://github.com/triton-inference-server/server/blob/main/docs/customization_guide/inference_protocols.md#limit-endpoint-access-beta) - VertexAI - Sagemaker -- After a running server has been stopped, if the client sends an inference request, a Segmentation Fault will occur. \ No newline at end of file +- After a running server has been stopped, if the client sends an inference request, a Segmentation Fault will occur. +- Using tritonclient.grpc and tritonserver in the same process may cause a crash or abort due to the lack of `fork()` support in [`cygrpc`](https://github.com/grpc/grpc/blob/master/doc/fork_support.md). \ No newline at end of file diff --git a/qa/L0_python_api/test.sh b/qa/L0_python_api/test.sh index 0d87d16771..c0671b13d3 100755 --- a/qa/L0_python_api/test.sh +++ b/qa/L0_python_api/test.sh @@ -51,7 +51,14 @@ fi FRONTEND_TEST_LOG="./python_kserve.log" -python -m pytest --junitxml=test_kserve.xml test_kserve.py > $FRONTEND_TEST_LOG 2>&1 +# TODO: [DLIS-7735] Run tritonclient.grpc as a separate process +# Currently, running tritonclient.grpc and tritonserver in the same process +# can non-deterministically abort/crash in a way that pytest cannot catch. +# This is because fork() is called by tritonserver on model load, +# which attempts to fork the imported libraries and their internal states, +# and cygrpc (a dependency of tritonclient.grpc) does not officially support fork(). +# Reference: https://github.com/grpc/grpc/blob/master/doc/fork_support.md +python -m pytest --junitxml=test_kserve.xml test_kserve.py -k "not KServeGrpc" > $FRONTEND_TEST_LOG 2>&1 if [ $? 
-ne 0 ]; then cat $FRONTEND_TEST_LOG echo -e "\n***\n*** Test Failed\n***" diff --git a/qa/L0_python_api/test_kserve.py b/qa/L0_python_api/test_kserve.py index cefd085cdb..e226714f9b 100644 --- a/qa/L0_python_api/test_kserve.py +++ b/qa/L0_python_api/test_kserve.py @@ -30,11 +30,20 @@ import numpy as np import pytest import testing_utils as utils -import tritonclient.grpc as grpcclient + +# TODO: [DLIS-7735] Run tritonclient.grpc as a separate process +# import tritonclient.grpc as grpcclient import tritonclient.http as httpclient import tritonserver from tritonclient.utils import InferenceServerException -from tritonfrontend import KServeGrpc, KServeHttp, Metrics +from tritonfrontend import ( + Feature, + FeatureGroup, + KServeGrpc, + KServeHttp, + Metrics, + RestrictedFeatures, +) class TestHttpOptions: @@ -108,8 +117,99 @@ def test_wrong_http_parameters(self): Metrics.Options(thread_count="ten") +class TestRestrictedFeatureOptions: + def test_correct_parameters(self): + # Directly test feature groups + correct_feature_group = FeatureGroup( + key="health-key", + value="health-val", + features=[Feature.HEALTH, Feature.METADATA], + ) + + rf = RestrictedFeatures(groups=[correct_feature_group]) + + rf.create_feature_group( + key="infer-key", value="infer-val", features=Feature.INFERENCE + ) + + assert all( + rf.has_feature(feature) + for feature in [Feature.HEALTH, Feature.METADATA, Feature.INFERENCE] + ) + + feature_list = rf.get_feature_groups() + expected_list = [ + FeatureGroup("health-key", "health-val", Feature.HEALTH), + FeatureGroup("health-key", "health-val", Feature.METADATA), + FeatureGroup("infer-key", "infer-val", Feature.INFERENCE), + ] + + # 3 groups: 1 for each Feature. + assert len(feature_list) == 3 + + # Converting FeatureGroup->str and sorting so the comparison is order-independent. + feature_str_repr = lambda feature_groups: sorted( + repr(group) for group in feature_groups + ) + + assert feature_str_repr(expected_list) == feature_str_repr(feature_list) + + # Updating Feature.METADATA with new (key, value) pair + rf.update_feature_group(Feature.METADATA, "metadata-key", "metadata-val") + expected_list[1] = FeatureGroup( + "metadata-key", "metadata-val", Feature.METADATA ) + + assert len(feature_list) == 3 + assert feature_str_repr(expected_list) == feature_str_repr(feature_list) + + rf.remove_features(Feature.INFERENCE) + assert len(rf.get_feature_groups()) == 2 + assert not rf.has_feature(Feature.INFERENCE) + + rf.remove_features(Feature.HEALTH) + assert len(rf.get_feature_groups()) == 1 + assert not rf.has_feature(Feature.HEALTH) and rf.has_feature(Feature.METADATA) + + def test_wrong_rf_parameters(self): + rf = RestrictedFeatures() + # Each element of the features list needs to be a tritonfrontend.Feature + with pytest.raises(tritonserver.InvalidArgumentError): + rf.create_feature_group(key="", value="", features=["health"]) + + # key and value need to be of type string + with pytest.raises(Exception): + rf.create_feature_group( + key=42, value="Secret to the Universe", features=[Feature.HEALTH] + ) + with pytest.raises(Exception): + rf.create_feature_group(key="", value=123, features=[Feature.HEALTH]) + + # Test collision of Features among individual Feature Groups + with pytest.raises( + tritonserver.InvalidArgumentError, + match="A given feature can only belong to one " "group. 
Feature.HEALTH already belongs to an existing group.", + ): + feature_group = FeatureGroup( + key="key", value="val", features=[Feature.METADATA, Feature.HEALTH] + ) + + rf = RestrictedFeatures(groups=[feature_group]) + rf.create_feature_group(key="key2", value="val", features=[Feature.HEALTH]) + + with pytest.raises( + tritonserver.InvalidArgumentError, + match="not present in any of the FeatureGroups for " + "the RestrictedFeatures object and therefore cannot be removed.", + ): + rf = RestrictedFeatures() + rf.remove_features(Feature.HEALTH) + + HTTP_ARGS = (KServeHttp, httpclient, "localhost:8000") # Default HTTP args -GRPC_ARGS = (KServeGrpc, grpcclient, "localhost:8001") # Default GRPC args +# TODO: [DLIS-7735] Run tritonclient.grpc as separate process +GRPC_ARGS = (KServeGrpc, None, "localhost:8001") # Default GRPC args METRICS_ARGS = (Metrics, "localhost:8002") # Default Metrics args @@ -251,6 +351,7 @@ def test_http_req_during_shutdown(self, frontend, client_type, url): def test_grpc_req_during_shutdown(self, frontend, client_type, url): server = utils.setup_server() grpc_service = utils.setup_service(server, frontend) + # TODO: [DLIS-7735] Run tritonclient.grpc as a separate process grpc_client = grpcclient.InferenceServerClient(url=url) user_data = [] @@ -302,6 +403,34 @@ def callback(user_data, result, error): utils.teardown_client(grpc_client) utils.teardown_server(server) + # KNOWN ISSUE: CAUSES SEGFAULT + # Created [DLIS-7231] to address at future date + # Once the server has been stopped, the underlying TRITONSERVER_Server instance + # is deleted. However, the frontend does not know the server instance + # is no longer valid. + # def test_inference_after_server_stop(self): + # server = utils.setup_server() + # http_service = utils.setup_service(server, KServeHttp) + # http_client = setup_client(httpclient, url="localhost:8000") + + # teardown_server(server) # Server has been stopped + + # model_name = "identity" + # input_data = np.array([["testing"]], dtype=object) + # # Create input and output objects + # inputs = [httpclient.InferInput("INPUT0", input_data.shape, "BYTES")] + # outputs = [httpclient.InferRequestedOutput("OUTPUT0")] + + # # Set the data for the input tensor + # inputs[0].set_data_from_numpy(input_data) + + # results = http_client.infer(model_name, inputs=inputs, outputs=outputs) + + # utils.teardown_client(http_client) + # utils.teardown_service(http_service) + + +class TestMetrics: @pytest.mark.parametrize("frontend, url", [METRICS_ARGS]) def test_metrics_default_port(self, frontend, url): server = utils.setup_server() @@ -330,10 +459,10 @@ def test_metrics_custom_port(self, frontend, port=8005): @pytest.mark.parametrize("frontend, url", [METRICS_ARGS]) def test_metrics_update(self, frontend, url): - # Setup Server, KServeGrpc, Metrics + # Setup Server, KServeHttp, Metrics server = utils.setup_server() - grpc_service = utils.setup_service( - server, KServeGrpc + http_service = utils.setup_service( + server, KServeHttp ) # Needed to send inference request metrics_service = utils.setup_service(server, frontend) @@ -344,7 +473,7 @@ def test_metrics_update(self, frontend, url): assert before_status_code == 200 and before_inference_count == 0 # Send 1 Inference Request with send_and_test_inference() - assert utils.send_and_test_inference_identity(GRPC_ARGS[1], GRPC_ARGS[2]) + assert utils.send_and_test_inference_identity(HTTP_ARGS[1], HTTP_ARGS[2]) # Get Metrics and verify inference count == 1 after inference after_status_code, after_inference_count = 
utils.get_metrics( @@ -352,33 +481,110 @@ def test_metrics_update(self, frontend, url): ) assert after_status_code == 200 and after_inference_count == 1 - # Teardown Metrics, GrpcService, Server - utils.teardown_service(grpc_service) + # Teardown Metrics, HttpService, Server + utils.teardown_service(http_service) utils.teardown_service(metrics_service) utils.teardown_server(server) - # KNOWN ISSUE: CAUSES SEGFAULT - # Created [DLIS-7231] to address at future date - # Once the server has been stopped, the underlying TRITONSERVER_Server instance - # is deleted. However, the frontend does not know the server instance - # is no longer valid. - # def test_inference_after_server_stop(self): - # server = utils.setup_server() - # http_service = utils.setup_service(server, KServeHttp) - # http_client = setup_client(httpclient, url="localhost:8000") - # teardown_server(server) # Server has been stopped +class TestRestrictedFeatures: + @pytest.mark.parametrize( + "frontend, client_type, url, key_prefix", + [HTTP_ARGS + ("",), GRPC_ARGS + ("triton-grpc-protocol-",)], + ) + def test_restrict_inference(self, frontend, client_type, url, key_prefix): + server = utils.setup_server() - # model_name = "identity" - # input_data = np.array([["testing"]], dtype=object) - # # Create input and output objects - # inputs = [httpclient.InferInput("INPUT0", input_data.shape, "BYTES")] - # outputs = [httpclient.InferRequestedOutput("OUTPUT0")] + # Specifying restricted features that restricts inference. + infer_key, infer_value = "infer-key", "infer-value" - # # Set the data for the input tensor - # inputs[0].set_data_from_numpy(input_data) + rf = RestrictedFeatures() + rf.create_feature_group( + key=infer_key, + value=infer_value, + features=[Feature.INFERENCE], + ) - # results = http_client.infer(model_name, inputs=inputs, outputs=outputs) + options = frontend.Options(restricted_features=rf) + service = utils.setup_service(server, frontend, options=options) + + # Valid headers sent with inference request + headers = {key_prefix + "infer-key": "infer-value"} + assert utils.send_and_test_inference_identity(client_type, url, headers) + + # Combinations of Invalid (or no) headers sent with inference request + invalid_key_value = {key_prefix + "fake-key": "fake-value"} + invalid_value = {key_prefix + infer_key: "fake-value"} + error_msg = f"expecting header '{key_prefix}infer-key'" + + for header, err_msg in [ + (invalid_key_value, error_msg), + (invalid_value, error_msg), + (None, error_msg), + ]: + with pytest.raises( + InferenceServerException, + match=err_msg, + ): + utils.send_and_test_inference_identity(client_type, url, header) - # utils.teardown_client(http_client) - # utils.teardown_service(http_service) + utils.teardown_service(service) + utils.teardown_server(server) + + @pytest.mark.parametrize( + "frontend, client_type, url, key_prefix", + [HTTP_ARGS + ("",), GRPC_ARGS + ("triton-grpc-protocol-",)], + ) + def test_multiple_groups(self, frontend, client_type, url, key_prefix): + server = utils.setup_server() + + # Credentials used to restrict/access Triton Features. 
+ model_repo_key, model_repo_val = "repo-key", "repo-value" + infer_key, infer_val = "infer-key", "infer-value" + + # Specifying restricted feature that restricts multiple groups + rf = RestrictedFeatures() + rf.create_feature_group( + key=model_repo_key, + value=model_repo_val, + features=[Feature.MODEL_REPOSITORY], + ) + rf.create_feature_group( + key=infer_key, value=infer_val, features=[Feature.INFERENCE] + ) + + options = frontend.Options(restricted_features=rf) + service = utils.setup_service(server, frontend, options=options) + client = utils.setup_client(client_type, url=url) + + # Testing if Feature.MODEL_REPOSITORY is restricted correctly + model_config_header = {key_prefix + model_repo_key: model_repo_val} + + model_repo_index = client.get_model_repository_index( + headers=model_config_header + ) + model_repo_contents = str(model_repo_index) + assert "delayed_identity" in model_repo_contents + + with pytest.raises( + InferenceServerException, + match=f"expecting header '{key_prefix}{model_repo_key}'", + ): + client.get_model_repository_index(headers={"fake-key": "fake-value"}) + + # Testing if Feature.INFERENCE is restricted correctly + infer_header = {key_prefix + infer_key: infer_val} + + assert utils.send_and_test_inference_identity(client_type, url, infer_header) + + with pytest.raises( + InferenceServerException, + match=f"expecting header '{key_prefix}{infer_key}'", + ): + utils.send_and_test_inference_identity( + client_type, url, {"fake-key": "fake-value"} + ) + + utils.teardown_client(client) + utils.teardown_service(service) + utils.teardown_server(server) diff --git a/qa/L0_python_api/testing_utils.py b/qa/L0_python_api/testing_utils.py index d80c4a233a..ca739ef783 100644 --- a/qa/L0_python_api/testing_utils.py +++ b/qa/L0_python_api/testing_utils.py @@ -99,6 +99,7 @@ def send_and_test_inference_identity( "tritonclient.grpc.InferenceServerClient", ], url: str, + headers=None, ) -> bool: """ Sends an inference request to the model at test_model_repository/identity @@ -115,7 +116,9 @@ def send_and_test_inference_identity( inputs[0].set_data_from_numpy(input_data) # Perform inference request - results = client.infer(model_name=model_name, inputs=inputs, outputs=outputs) + results = client.infer( + model_name=model_name, inputs=inputs, outputs=outputs, headers=headers + ) output_data = results.as_numpy("OUTPUT0") # Gather output data diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index 74ec443ae6..3dc38fb8c5 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -2445,13 +2445,15 @@ Server::Create( { Options grpc_options; - RETURN_IF_ERR(GetOptions(grpc_options, options)); + RETURN_IF_ERR(GetOptions(grpc_options, options, restricted_features)); return Create(server, trace_manager, shm_manager, grpc_options, service); } TRITONSERVER_Error* -Server::GetOptions(Options& options, UnorderedMapType& options_map) +Server::GetOptions( + Options& options, UnorderedMapType& options_map, + const RestrictedFeatures& restricted_features) { SocketOptions socket_selection; SslOptions ssl_selection; @@ -2475,6 +2477,8 @@ Server::GetOptions(Options& options, UnorderedMapType& options_map) RETURN_IF_ERR(GetValue( options_map, "forward_header_pattern", &options.forward_header_pattern_)); + options.restricted_protocols_ = restricted_features; + return nullptr; } diff --git a/src/grpc/grpc_server.h b/src/grpc/grpc_server.h index 89d8dc7388..15c1051a18 100644 --- a/src/grpc/grpc_server.h +++ b/src/grpc/grpc_server.h @@ -1,4 +1,4 @@ -// Copyright 2019-2023, NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -128,7 +128,9 @@ class Server { KeepAliveOptions& options, UnorderedMapType& options_map); static TRITONSERVER_Error* GetOptions( - Options& options, UnorderedMapType& options_map); + Options& options, UnorderedMapType& options_map, + const RestrictedFeatures& restricted_features); + std::shared_ptr<TRITONSERVER_Server> tritonserver_; TraceManager* trace_manager_; diff --git a/src/python/examples/example.py b/src/python/examples/example.py index 2d2ca78920..e0bfa917c9 100644 --- a/src/python/examples/example.py +++ b/src/python/examples/example.py @@ -60,7 +60,7 @@ def main(): client = httpclient.InferenceServerClient(url=url) # Prepare input data - input_data = np.array([["Roger Roger"]], dtype=object) + input_data = np.array(["Roger Roger"], dtype=object) # Create input and output objects inputs = [httpclient.InferInput("INPUT0", input_data.shape, "BYTES")] diff --git a/src/python/examples/example_model_repository/identity/1/model.py b/src/python/examples/example_model_repository/identity/1/model.py new file mode 100644 index 0000000000..6b25af3e45 --- /dev/null +++ b/src/python/examples/example_model_repository/identity/1/model.py @@ -0,0 +1,41 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + """An identity model that returns the input tensor as output.""" + + def initialize(self, args): + pass + + def execute(self, requests): + responses = [] + for request in requests: + in_0 = pb_utils.get_input_tensor_by_name(request, "INPUT0") + out_tensor_0 = pb_utils.Tensor("OUTPUT0", in_0.as_numpy()) + responses.append(pb_utils.InferenceResponse([out_tensor_0])) + return responses diff --git a/src/python/examples/example_model_repository/identity/1/model.savedmodel/saved_model.pb b/src/python/examples/example_model_repository/identity/1/model.savedmodel/saved_model.pb deleted file mode 100755 index 63f78fecb4..0000000000 Binary files a/src/python/examples/example_model_repository/identity/1/model.savedmodel/saved_model.pb and /dev/null differ diff --git a/src/python/examples/example_model_repository/identity/config.pbtxt b/src/python/examples/example_model_repository/identity/config.pbtxt index ae83e47556..2a61d3edf2 100644 --- a/src/python/examples/example_model_repository/identity/config.pbtxt +++ b/src/python/examples/example_model_repository/identity/config.pbtxt @@ -1,4 +1,4 @@ -# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -25,20 +25,26 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. name: "identity" -platform: "tensorflow_savedmodel" -max_batch_size: 8 +backend: "python" +max_batch_size: 0 input [ { name: "INPUT0" data_type: TYPE_STRING - dims: [ -1 ] + dims: [ 1 ] } ] output [ { name: "OUTPUT0" data_type: TYPE_STRING - dims: [ -1 ] + dims: [ 1 ] + } +] +instance_group [ + { + count: 1 + kind : KIND_CPU } ] diff --git a/src/python/tritonfrontend/CMakeLists.txt b/src/python/tritonfrontend/CMakeLists.txt index 8d9b997ce5..7ad2ebed68 100644 --- a/src/python/tritonfrontend/CMakeLists.txt +++ b/src/python/tritonfrontend/CMakeLists.txt @@ -68,6 +68,10 @@ set( set(PY_BINDING_DEPENDENCY_LIBS b64) # Dependency from common.h +list(APPEND PY_BINDING_DEPENDENCY_LIBS + triton-common-json + ) + # Conditional Linking Based on Flags if(${TRITON_ENABLE_HTTP}) list(APPEND PY_BINDING_DEPENDENCY_LIBS diff --git a/src/python/tritonfrontend/__init__.py b/src/python/tritonfrontend/__init__.py index 60ea2b7050..1e971c4bba 100644 --- a/src/python/tritonfrontend/__init__.py +++ b/src/python/tritonfrontend/__init__.py @@ -29,6 +29,8 @@ import builtins from importlib.metadata import PackageNotFoundError, version +from tritonfrontend._api import Feature, FeatureGroup, RestrictedFeatures + try: from tritonfrontend._api import KServeHttp except ImportError: diff --git a/src/python/tritonfrontend/_api/__init__.py b/src/python/tritonfrontend/_api/__init__.py index 0c73f92df4..fc98809e33 100644 --- a/src/python/tritonfrontend/_api/__init__.py +++ b/src/python/tritonfrontend/_api/__init__.py @@ -24,6 +24,9 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from ._restricted_features import Feature, FeatureGroup, RestrictedFeatures + try: from ._kservehttp import KServeHttp except ImportError: diff --git a/src/python/tritonfrontend/_api/_error_mapping.py b/src/python/tritonfrontend/_api/_error_mapping.py index 8bd1764c00..8e04187f69 100644 --- a/src/python/tritonfrontend/_api/_error_mapping.py +++ b/src/python/tritonfrontend/_api/_error_mapping.py @@ -25,6 +25,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import sys +from functools import wraps import tritonserver from tritonfrontend._c.tritonfrontend_bindings import ( @@ -52,9 +53,10 @@ def handle_triton_error(func): + @wraps(func) # Retains the original function's signature. def error_handling_wrapper(*args, **kwargs): try: - func(*args, **kwargs) + return func(*args, **kwargs) except TritonError: exc_type, exc_value, _ = sys.exc_info() # raise ... from None masks the tritonfrontend Error from being added in traceback diff --git a/src/python/tritonfrontend/_api/_kservegrpc.py b/src/python/tritonfrontend/_api/_kservegrpc.py index efa706a77a..ea32012ce9 100644 --- a/src/python/tritonfrontend/_api/_kservegrpc.py +++ b/src/python/tritonfrontend/_api/_kservegrpc.py @@ -28,9 +28,10 @@ from typing import Union import tritonserver -from pydantic import Field +from pydantic import ConfigDict, Field from pydantic.dataclasses import dataclass from tritonfrontend._api._error_mapping import handle_triton_error +from tritonfrontend._api._restricted_features import RestrictedFeatures from tritonfrontend._c.tritonfrontend_bindings import ( InvalidArgumentError, TritonFrontendGrpc, @@ -52,7 +53,7 @@ class KServeGrpc: ) # triton::server::grpc::Options - @dataclass + @dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class Options: # triton::server::grpc::SocketOptions address: str = "0.0.0.0" @@ -81,13 +82,19 @@ class Options: ] = Grpc_compression_level.NONE infer_allocation_pool_size: int = Field(8, ge=0) forward_header_pattern: str = "" - # DLIS-7215: Add restricted protocol support - # restricted_protocols: str = "" + restricted_features: RestrictedFeatures = RestrictedFeatures() + @handle_triton_error def __post_init__(self): if isinstance(self.infer_compression_level, Grpc_compression_level): self.infer_compression_level = self.infer_compression_level.value + if not isinstance(self.restricted_features, RestrictedFeatures): + raise InvalidArgumentError( + "restricted_features needs an instance of RestrictedFeatures." 
+ ) + self.restricted_features = repr(self.restricted_features) + @handle_triton_error def __init__(self, server: tritonserver, options: "KServeGrpc.Options" = None): server_ptr = server._ptr() # TRITONSERVER_Server pointer diff --git a/src/python/tritonfrontend/_api/_kservehttp.py b/src/python/tritonfrontend/_api/_kservehttp.py index 6002a3f3f5..dd7e148b9d 100644 --- a/src/python/tritonfrontend/_api/_kservehttp.py +++ b/src/python/tritonfrontend/_api/_kservehttp.py @@ -28,9 +28,10 @@ from typing import Union import tritonserver -from pydantic import Field +from pydantic import ConfigDict, Field from pydantic.dataclasses import dataclass from tritonfrontend._api._error_mapping import handle_triton_error +from tritonfrontend._api._restricted_features import RestrictedFeatures from tritonfrontend._c.tritonfrontend_bindings import ( InvalidArgumentError, TritonFrontendHttp, @@ -38,15 +39,22 @@ class KServeHttp: - @dataclass + @dataclass(config=ConfigDict(arbitrary_types_allowed=True)) class Options: address: str = "0.0.0.0" port: int = Field(8000, ge=0, le=65535) reuse_port: bool = False thread_count: int = Field(8, gt=0) header_forward_pattern: str = "" - # DLIS-7215: Add restricted protocol support - # restricted_protocols: list + restricted_features: RestrictedFeatures = RestrictedFeatures() + + @handle_triton_error + def __post_init__(self): + if not isinstance(self.restricted_features, RestrictedFeatures): + raise InvalidArgumentError( + "restricted_features needs an instance of RestrictedFeatures." + ) + self.restricted_features = repr(self.restricted_features) @handle_triton_error def __init__(self, server: tritonserver, options: "KServeHttp.Options" = None): diff --git a/src/python/tritonfrontend/_api/_restricted_features.py b/src/python/tritonfrontend/_api/_restricted_features.py new file mode 100644 index 0000000000..cb96bb7a38 --- /dev/null +++ b/src/python/tritonfrontend/_api/_restricted_features.py @@ -0,0 +1,292 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import json +from copy import deepcopy +from enum import Enum +from typing import List, Union + +from pydantic import field_validator +from pydantic.dataclasses import dataclass +from tritonfrontend._api._error_mapping import handle_triton_error +from tritonfrontend._c.tritonfrontend_bindings import ( + InvalidArgumentError, + NotFoundError, +) + + +class Feature(Enum): + """ + List of Features that are provided by the KServeHttp and KServeGrpc endpoints for the Server. + 1-to-1 copy of the RestrictedCategory Enum from https://github.com/triton-inference-server/server/blob/main/src/restricted_features.h + """ + + HEALTH = "health" + METADATA = "metadata" + INFERENCE = "inference" + SHM_MEMORY = "shared-memory" + MODEL_CONFIG = "model-config" + MODEL_REPOSITORY = "model-repository" + STATISTICS = "statistics" + TRACE = "trace" + LOGGING = "logging" + + +@dataclass(frozen=True) +class FeatureGroup: + """ + Stores instances of (key, value, features) and performs type validation on the instance. + Used by the RestrictedFeatures class. + + Example: + >>> from tritonfrontend import Feature, FeatureGroup + >>> infer_group = FeatureGroup("infer-key", "infer-value", Feature.INFERENCE) + >>> info_group = FeatureGroup("admin-key", "admin-value", [Feature.HEALTH, Feature.METADATA]) + >>> health_group = FeatureGroup("key", "value", ["health"]) # Will Error Out + Invalid features found: ['health'] ... Valid options are: ['Feature.HEALTH', + 'Feature.METADATA', 'Feature.INFERENCE', 'Feature.SHM_MEMORY', 'Feature.MODEL_CONFIG', + 'Feature.MODEL_REPOSITORY', 'Feature.STATISTICS', 'Feature.TRACE', 'Feature.LOGGING'] + """ + + key: str + value: str + features: Union[List[Feature], Feature] + + @field_validator("features", mode="before") + @handle_triton_error + def validate_features(features: List[Feature] | Feature) -> List[Feature]: + if isinstance(features, Feature): + features = [features] + + if not isinstance(features, list): + raise InvalidArgumentError( + "FeatureGroup.features needs to be of type Feature or List[Feature]" + ) + + invalid_features = [item for item in features if not isinstance(item, Feature)] + if invalid_features: + raise InvalidArgumentError( + f"Invalid features found: {invalid_features}. " + "Each item in 'features' should be an instance of tritonfrontend.Feature. " + f"Valid options are: {[str(p) for p in Feature]}" + ) + return features + + +class RestrictedFeatures: + """ + Using `RestrictedFeatures`, users can restrict access to certain features provided by + the `KServeHttp` and `KServeGrpc` frontends. In order to use a restricted feature, + the key-value pair ({`key`:`value`}) needs to be sent as a header with the request to the endpoint. + Note: For the `KServeGrpc` endpoint, the header key needs a prefix of `triton-grpc-protocol-` + when sending a request. + + Internally, the `RestrictedFeatures` class: + - Stores collections of FeatureGroup instances + - Breaks FeatureGroups down so that each stored group holds a single Feature + - Checks for collisions of `Feature` instances among groups. + - Serializes the data into a JSON string. 
+ + Example: + >>> from tritonfrontend import Feature, FeatureGroup, RestrictedFeatures + >>> admin_group = FeatureGroup(key="admin-key", value="admin-value", features=[Feature.HEALTH, Feature.METADATA]) + >>> infer_group = FeatureGroup("infer-key", "infer-value", [Feature.INFERENCE]) + >>> rf = RestrictedFeatures([admin_group, infer_group]) + >>> rf.create_feature_group("trace-key", "trace-value", [Feature.TRACE]) + >>> rf + [ + {"key": "admin-key", "value": "admin-value", "features": ["health"]}, + {"key": "admin-key", "value": "admin-value", "features": ["metadata"]}, + {"key": "infer-key", "value": "infer-value", "features": ["inference"]}, + {"key": "trace-key", "value": "trace-value", "features": ["trace"]} + ] + """ + + def __init__(self, groups: List[FeatureGroup] = []): + self.feature_groups = [] # Stores FeatureGroup Instances + self.features_restricted = set() # Used for collision detection between groups + + for feat_group in groups: + self.add_feature_group(feat_group) + + @handle_triton_error + def add_feature_group(self, group: FeatureGroup) -> None: + """ + Adds a FeatureGroup to the RestrictedFeatures object. + If n Features are in a FeatureGroup, it is broken into n FeatureGroup instances, + allowing for granular edits later. + + Example: + >>> from tritonfrontend import Feature, FeatureGroup, RestrictedFeatures + >>> health_group = FeatureGroup("health-key", "health-value", [Feature.HEALTH]) + >>> rf = RestrictedFeatures() + >>> rf.add_feature_group(health_group) + >>> rf + [{"key": "health-key", "value": "health-value", "features": ["health"]}] + """ + for feat in group.features: + if self.has_feature(feat): + raise InvalidArgumentError( + "A given feature can only belong to one group. " + f"{str(feat)} already belongs to an existing group." + ) + + new_group = FeatureGroup(group.key, group.value, feat) + self.features_restricted.add(feat) + self.feature_groups.append(new_group) + + def create_feature_group( + self, key: str, value: str, features: List[Feature] | Feature + ) -> None: + """ + Factory method used to generate FeatureGroup instances and append them + to the `RestrictedFeatures` object that invoked this function. 
+ + Example: + >>> from tritonfrontend import RestrictedFeatures, Feature + >>> rf = RestrictedFeatures() + >>> rf.create_feature_group("infer-key", "infer-value", Feature.INFERENCE) + >>> rf.create_feature_group("meta-key", "meta-value", [Feature.METADATA, Feature.HEALTH]) + >>> rf + [ + {"key": "infer-key", "value": "infer-value", "features": ["inference"]}, + {"key": "meta-key", "value": "meta-value", "features": ["metadata"]}, + {"key": "meta-key", "value": "meta-value", "features": ["health"]}, + ] + """ + group = FeatureGroup(key, value, features) + self.add_feature_group(group) + + def has_feature(self, feature: Feature) -> bool: + """ + Checks if a feature belongs to any of the groups. + Example: + >>> from tritonfrontend import RestrictedFeatures, Feature + >>> rf = RestrictedFeatures() + >>> rf.create_feature_group("infer-key", "infer-value", [Feature.INFERENCE]) + >>> rf + [{"key": "infer-key", "value": "infer-value", "features": ["inference"]}] + >>> rf.has_feature(Feature.INFERENCE) + True + >>> rf.has_feature(Feature.TRACE) + False + """ + return feature in self.features_restricted + + @handle_triton_error + def update_feature_group(self, feature: Feature, key: str, value: str) -> None: + """ + Updates the key and value used to restrict a Feature. + Example: + >>> from tritonfrontend import RestrictedFeatures, Feature + >>> rf = RestrictedFeatures() + >>> rf.create_feature_group("meta-key", "meta-value", [Feature.METADATA, Feature.HEALTH]) + >>> rf.update_feature_group(Feature.HEALTH, "health-key", "health-value") + >>> rf + [ + {"key": "meta-key", "value": "meta-value", "features": ["metadata"]}, + {"key": "health-key", "value": "health-value", "features": ["health"]} + ] + """ + if not self.has_feature(feature): + raise NotFoundError(f"{str(feature)} is not being restricted.") + + for idx, group in enumerate(self.feature_groups): + if feature in group.features: + self.feature_groups[idx] = FeatureGroup(key, value, feature) + break + + @handle_triton_error + def remove_features(self, features: List[Feature] | Feature) -> None: + """ + Removes the FeatureGroups that contain the specified features. + Example: + >>> from tritonfrontend import RestrictedFeatures, Feature + >>> admin_group = FeatureGroup(key="admin-key", value="admin-value", features=[Feature.HEALTH, Feature.METADATA]) + >>> infer_group = FeatureGroup("infer-key", "infer-value", [Feature.INFERENCE]) + >>> mem_group = FeatureGroup("mem-key", "mem-value", [Feature.SHM_MEMORY]) + >>> rf = RestrictedFeatures([admin_group, infer_group, mem_group]) + >>> rf.remove_features([Feature.HEALTH, Feature.SHM_MEMORY]) + >>> rf + [ + {"key": "admin-key", "value": "admin-value", "features": ["metadata"]}, + {"key": "infer-key", "value": "infer-value", "features": ["inference"]} + ] + """ + if isinstance(features, Feature): + features = [features] + + not_present = [feat for feat in features if not self.has_feature(feat)] + if not_present: + raise InvalidArgumentError( + f"{not_present} is not present in any of the FeatureGroups for " + "the RestrictedFeatures object and therefore cannot be removed." + ) + + feature_set = set(features) + target_groups = [ + group + for group in self.feature_groups + if feature_set.intersection(group.features) + ] + for group in target_groups: + self.feature_groups.remove(group) + + for feat in group.features: + self.features_restricted.discard(feat) + + def get_feature_groups(self) -> List[FeatureGroup]: + """ + Returns a list of feature groups. 
+ """ + return self.feature_groups + + def _gather_restricted_data(self) -> dict: + """ + Represents `RestrictedFeatures` Instance as a dictionary. + Additionally, converts `Feature` instances to str equivalent. + """ + # Dataclass_Instance.__dict__ provides shallow copy, so need a deep copy IF modifying + rfeat_data = [ + deepcopy(feat_group.__dict__) for feat_group in self.feature_groups + ] + + for idx in range(len(rfeat_data)): + rfeat_data[idx]["features"] = [ + feat.value for feat in rfeat_data[idx]["features"] + ] + + return rfeat_data + + def __str__(self) -> str: + """ + A function to retrieve user-friendly string to view object contents. + """ + return json.dumps(self._gather_restricted_data(), indent=2) + + def __repr__(self) -> str: + """ + A function to retrieve representation that has not been formatted. + """ + return json.dumps(self._gather_restricted_data()) diff --git a/src/python/tritonfrontend/_c/tritonfrontend.h b/src/python/tritonfrontend/_c/tritonfrontend.h index ca215c15c8..7f0bf995aa 100644 --- a/src/python/tritonfrontend/_c/tritonfrontend.h +++ b/src/python/tritonfrontend/_c/tritonfrontend.h @@ -30,11 +30,22 @@ #include #include + +#ifdef TRITON_ENABLE_GRPC +#include "../../../grpc/grpc_server.h" +#endif + + +#if defined(TRITON_ENABLE_HTTP) || defined(TRITON_ENABLE_METRICS) +#include "../../../http_server.h" +#endif + + #include "../../../common.h" #include "../../../restricted_features.h" #include "../../../shared_memory_manager.h" -#include "../../../tracer.h" #include "triton/common/logging.h" +#include "triton/common/triton_json.h" #include "triton/core/tritonserver.h" @@ -115,6 +126,7 @@ class TritonFrontend { reinterpret_cast(server_mem_addr); server_.reset(server_ptr, EmptyDeleter); + TritonFrontend::_populate_restricted_features(data, restricted_features); #ifdef TRITON_ENABLE_HTTP if constexpr (std::is_same_v) { @@ -153,6 +165,64 @@ class TritonFrontend { // will cause a double-free when the core bindings attempt to // delete the TRITONSERVER_Server instance. 
static void EmptyDeleter(TRITONSERVER_Server* obj){}; -}; + static void _populate_restricted_features( + UnorderedMapType& data, RestrictedFeatures& rest_features) + { + std::string map_key = + "restricted_features"; // Name of option in UnorderedMap + std::string key_prefix; // Prefix for header key + +#if !defined(TRITON_ENABLE_HTTP) && !defined(TRITON_ENABLE_GRPC) + return; // Frontend does not support RestrictedFeatures +#endif + +#if defined(TRITON_ENABLE_HTTP) && defined(TRITON_ENABLE_METRICS) + if constexpr (std::is_same_v) { + return; // Metrics does not support RestrictedFeatures + } +#endif + +#ifdef TRITON_ENABLE_HTTP + if (std::is_same_v) { + key_prefix = ""; + } +#endif + +#ifdef TRITON_ENABLE_GRPC + if (std::is_same_v) { + key_prefix = "triton-grpc-protocol-"; + } +#endif + + std::string restricted_info; + ThrowIfError(GetValue(data, map_key, &restricted_info)); + + triton::common::TritonJson::Value rf_groups; + ThrowIfError(rf_groups.Parse(restricted_info)); + + std::string key, value, feature; + for (size_t group_idx = 0; group_idx < rf_groups.ArraySize(); group_idx++) { + triton::common::TritonJson::Value feature_group; + ThrowIfError(rf_groups.IndexAsObject(group_idx, &feature_group)); + + // Extract key and value + ThrowIfError(feature_group.MemberAsString("key", &key)); + ThrowIfError(feature_group.MemberAsString("value", &value)); + + triton::common::TritonJson::Value features; + ThrowIfError(feature_group.MemberAsArray("features", &features)); + + // Extract feature list + for (size_t feature_idx = 0; feature_idx < features.ArraySize(); + feature_idx++) { + ThrowIfError(features.IndexAsString(feature_idx, &feature)); + + rest_features.Insert( + RestrictedFeatures::ToCategory(feature), + std::make_pair(key_prefix + key, value)); + } + } + }; +}; }}} // namespace triton::server::python diff --git a/src/python/tritonfrontend/_c/tritonfrontend_pybind.cc b/src/python/tritonfrontend/_c/tritonfrontend_pybind.cc index cad21800bc..1474f55d2c 100644 --- a/src/python/tritonfrontend/_c/tritonfrontend_pybind.cc +++ b/src/python/tritonfrontend/_c/tritonfrontend_pybind.cc @@ -27,17 +27,6 @@ #include #include -#ifdef TRITON_ENABLE_GRPC -#include "../../../grpc/grpc_server.h" -#endif - - -#if defined(TRITON_ENABLE_HTTP) || defined(TRITON_ENABLE_METRICS) -#include "../../../http_server.h" -#endif - - -#include "triton/core/tritonserver.h" #include "tritonfrontend.h"
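
As context for the `-k "not KServeGrpc"` filter and the cygrpc bullet in Known Issues: until [DLIS-7735] is resolved, the workaround is to keep `tritonclient.grpc` out of the `tritonserver` process entirely. A minimal, hypothetical sketch (assuming a server plus gRPC frontend already running on `localhost:8001` and the `identity` example model) might look like:
```python
import multiprocessing as mp

import numpy as np


def grpc_infer(url: str, model_name: str) -> None:
    # Import inside the child process so cygrpc is never loaded
    # into the process hosting tritonserver.
    import tritonclient.grpc as grpcclient

    client = grpcclient.InferenceServerClient(url=url)
    input_data = np.array(["Roger Roger"], dtype=object)
    inputs = [grpcclient.InferInput("INPUT0", input_data.shape, "BYTES")]
    inputs[0].set_data_from_numpy(input_data)
    results = client.infer(model_name, inputs=inputs)
    print("Output data:", results.as_numpy("OUTPUT0"))


if __name__ == "__main__":
    # "spawn" gives the client a fresh interpreter rather than a fork()
    # of the server process, sidestepping cygrpc's missing fork() support.
    ctx = mp.get_context("spawn")
    p = ctx.Process(target=grpc_infer, args=("localhost:8001", "identity"))
    p.start()
    p.join()
```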