diff --git a/qa/L0_client_cancellation/client_cancellation_test.py b/qa/L0_client_cancellation/client_cancellation_test.py
index d14ee230dd..c009b7916a 100755
--- a/qa/L0_client_cancellation/client_cancellation_test.py
+++ b/qa/L0_client_cancellation/client_cancellation_test.py
@@ -32,9 +32,9 @@
 
 import asyncio
 import queue
+import time
 import unittest
 from functools import partial
-import time
 
 import numpy as np
 import test_util as tu
@@ -54,6 +54,7 @@ def callback(user_data, result, error):
     else:
         user_data._completed_requests.put(result)
 
+
 class ClientCancellationTest(tu.TestResultCollector):
     def setUp(self):
         self.model_name_ = "custom_identity_int32"
@@ -69,13 +70,13 @@ def _record_end_time_ms(self):
 
     def _test_runtime_duration(self, upper_limit):
         self.assertTrue(
-                (self._end_time_ms - self._start_time_ms) < upper_limit,
-                "test runtime expected less than "
-                + str(upper_limit)
-                + "ms response time, got "
-                + str(self._end_time_ms - self._start_time_ms)
-                + " ms",
-            )
+            (self._end_time_ms - self._start_time_ms) < upper_limit,
+            "test runtime expected less than "
+            + str(upper_limit)
+            + "ms response time, got "
+            + str(self._end_time_ms - self._start_time_ms)
+            + " ms",
+        )
 
     def _prepare_request(self):
         self.inputs_ = []
@@ -85,7 +86,6 @@ def _prepare_request(self):
 
         self.inputs_[0].set_data_from_numpy(self.input0_data_)
 
-
     def test_grpc_async_infer(self):
         # Sends a request using async_infer to a
         # model that takes 10s to execute. Issues
@@ -115,13 +115,13 @@ def test_grpc_async_infer(self):
         # Wait until the results is captured via callback
         data_item = user_data._completed_requests.get()
         self.assertEqual(type(data_item), grpcclient.CancelledError)
-        
+
         self._record_end_time_ms()
         self._test_runtime_duration(5000)
 
     def test_grpc_stream_infer(self):
         # Sends a request using async_stream_infer to a
-        # model that takes 10s to execute. Issues stream 
+        # model that takes 10s to execute. Issues stream
         # closure with cancel_requests=True. The client
         # should return with appropriate exception within
         # 5s.
@@ -134,9 +134,7 @@ def test_grpc_stream_infer(self):
 
         # The model is configured to take three seconds to send the
         # response. Expect an exception for small timeout values.
-        triton_client.start_stream(
-            callback=partial(callback, user_data)
-        )
+        triton_client.start_stream(callback=partial(callback, user_data))
        self._record_start_time_ms()
         for i in range(1):
             triton_client.async_stream_infer(
@@ -148,11 +146,10 @@ def test_grpc_stream_infer(self):
 
         data_item = user_data._completed_requests.get()
         self.assertEqual(type(data_item), grpcclient.CancelledError)
-        
+
         self._record_end_time_ms()
         self._test_runtime_duration(5000)
 
-
     def test_aio_grpc_async_infer(self):
         # Sends a request using infer of grpc.aio to a
         # model that takes 10s to execute. Issues
@@ -187,7 +184,6 @@ async def test_aio_infer(self):
 
             self._record_end_time_ms()
             self._test_runtime_duration(5000)
-
         asyncio.run(test_aio_infer(self))
 
     def test_aio_grpc_stream_infer(self):
@@ -198,17 +194,23 @@ def test_aio_grpc_stream_infer(self):
         # 5s.
 
         async def test_aio_streaming_infer(self):
             async with aiogrpcclient.InferenceServerClient(
-                url="localhost:8001", verbose=True) as triton_client:
+                url="localhost:8001", verbose=True
+            ) as triton_client:
+
                 async def async_request_iterator():
                     for i in range(1):
                         await asyncio.sleep(1)
-                        yield {"model_name": self.model_name_,
+                        yield {
+                            "model_name": self.model_name_,
                             "inputs": self.inputs_,
-                            "outputs": self.outputs_}
+                            "outputs": self.outputs_,
+                        }
 
                 self._prepare_request()
                 self._record_start_time_ms()
-                response_iterator = triton_client.stream_infer(inputs_iterator=async_request_iterator(), get_call_obj=True)
+                response_iterator = triton_client.stream_infer(
+                    inputs_iterator=async_request_iterator(), get_call_obj=True
+                )
                 streaming_call = await response_iterator.__anext__()
 
@@ -228,5 +230,6 @@ async def handle_response(response_iterator):
 
         asyncio.run(test_aio_streaming_infer(self))
 
+
 if __name__ == "__main__":
     unittest.main()