Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
tanmayv25 committed Sep 8, 2023
1 parent 957674a commit 6fb2ce2
Showing 1 changed file with 24 additions and 21 deletions.
45 changes: 24 additions & 21 deletions qa/L0_client_cancellation/client_cancellation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@

import asyncio
import queue
import time
import unittest
from functools import partial
import time

import numpy as np
import test_util as tu
Expand All @@ -54,6 +54,7 @@ def callback(user_data, result, error):
else:
user_data._completed_requests.put(result)


class ClientCancellationTest(tu.TestResultCollector):
def setUp(self):
self.model_name_ = "custom_identity_int32"
Expand All @@ -69,13 +70,13 @@ def _record_end_time_ms(self):

def _test_runtime_duration(self, upper_limit):
    """Assert the recorded test runtime stayed under *upper_limit* milliseconds.

    Relies on self._start_time_ms / self._end_time_ms having been set by
    the _record_start_time_ms / _record_end_time_ms helpers before the call.

    Args:
        upper_limit: maximum allowed elapsed time in milliseconds.

    Raises:
        AssertionError: if the elapsed time is not strictly below upper_limit.
    """
    elapsed_ms = self._end_time_ms - self._start_time_ms
    # f-string replaces the original chained str() concatenation; message
    # content is unchanged in substance: expected bound and observed time.
    self.assertTrue(
        elapsed_ms < upper_limit,
        f"test runtime expected less than {upper_limit}ms response time, "
        f"got {elapsed_ms} ms",
    )

def _prepare_request(self):
self.inputs_ = []
Expand All @@ -85,7 +86,6 @@ def _prepare_request(self):

self.inputs_[0].set_data_from_numpy(self.input0_data_)


def test_grpc_async_infer(self):
# Sends a request using async_infer to a
# model that takes 10s to execute. Issues
Expand Down Expand Up @@ -115,13 +115,13 @@ def test_grpc_async_infer(self):
# Wait until the results is captured via callback
data_item = user_data._completed_requests.get()
self.assertEqual(type(data_item), grpcclient.CancelledError)

self._record_end_time_ms()
self._test_runtime_duration(5000)

def test_grpc_stream_infer(self):
# Sends a request using async_stream_infer to a
# model that takes 10s to execute. Issues stream
# model that takes 10s to execute. Issues stream
# closure with cancel_requests=True. The client
# should return with appropriate exception within
# 5s.
Expand All @@ -134,9 +134,7 @@ def test_grpc_stream_infer(self):

# The model is configured to take three seconds to send the
# response. Expect an exception for small timeout values.
triton_client.start_stream(
callback=partial(callback, user_data)
)
triton_client.start_stream(callback=partial(callback, user_data))
self._record_start_time_ms()
for i in range(1):
triton_client.async_stream_infer(
Expand All @@ -148,11 +146,10 @@ def test_grpc_stream_infer(self):

data_item = user_data._completed_requests.get()
self.assertEqual(type(data_item), grpcclient.CancelledError)

self._record_end_time_ms()
self._test_runtime_duration(5000)


def test_aio_grpc_async_infer(self):
# Sends a request using infer of grpc.aio to a
# model that takes 10s to execute. Issues
Expand Down Expand Up @@ -187,7 +184,6 @@ async def test_aio_infer(self):
self._record_end_time_ms()
self._test_runtime_duration(5000)


asyncio.run(test_aio_infer(self))

def test_aio_grpc_stream_infer(self):
Expand All @@ -198,17 +194,23 @@ def test_aio_grpc_stream_infer(self):
# 5s.
async def test_aio_streaming_infer(self):
async with aiogrpcclient.InferenceServerClient(
url="localhost:8001", verbose=True) as triton_client:
url="localhost:8001", verbose=True
) as triton_client:

async def async_request_iterator():
for i in range(1):
await asyncio.sleep(1)
yield {"model_name": self.model_name_,
yield {
"model_name": self.model_name_,
"inputs": self.inputs_,
"outputs": self.outputs_}
"outputs": self.outputs_,
}

self._prepare_request()
self._record_start_time_ms()
response_iterator = triton_client.stream_infer(inputs_iterator=async_request_iterator(), get_call_obj=True)
response_iterator = triton_client.stream_infer(
inputs_iterator=async_request_iterator(), get_call_obj=True
)
streaming_call = await response_iterator.__anext__()

async def cancel_streaming(streaming_call):
Expand All @@ -228,5 +230,6 @@ async def handle_response(response_iterator):

asyncio.run(test_aio_streaming_infer(self))


if __name__ == "__main__":
unittest.main()

0 comments on commit 6fb2ce2

Please sign in to comment.