Add tests for gRPC client-side cancellation #6278

Merged: 11 commits, Sep 19, 2023
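The new tests cover two client-side cancellation paths in the tritonclient.grpc API: cancelling the future returned by async_infer(), and closing a stream with stop_stream(cancel_requests=True). A minimal sketch of the two paths, distilled from the tests below (URL and model name mirror the test setup; timing checks and result verification are omitted):

import numpy as np
import tritonclient.grpc as grpcclient

def on_complete(result, error):
    # On cancellation, `error` carries "Locally cancelled by application!".
    print(error if error else result)

client = grpcclient.InferenceServerClient(url="localhost:8001")
inputs = [grpcclient.InferInput("INPUT0", [1, 1], "INT32")]
inputs[0].set_data_from_numpy(np.array([[10]], dtype=np.int32))

# Path 1: cancel a single asynchronous request via the returned future.
future = client.async_infer(
    model_name="custom_identity_int32", inputs=inputs, callback=on_complete
)
future.cancel()

# Path 2: close the stream, cancelling all in-flight requests.
client.start_stream(callback=on_complete)
client.async_stream_infer(model_name="custom_identity_int32", inputs=inputs)
client.stop_stream(cancel_requests=True)

In both paths the registered callback receives an InferenceServerException describing the local cancellation; the tests assert on that message and on an upper bound for wall-clock runtime.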
249 changes: 249 additions & 0 deletions qa/L0_request_cancellation/client_cancellation_test.py
@@ -0,0 +1,249 @@
#!/usr/bin/env python3

# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import sys

sys.path.append("../common")

# asyncio is only used by the disabled AsyncIO tests below (see DLIS-5476);
# CodeQL flags it as an unused import in the meantime.
import asyncio
import queue
import time
import unittest
from functools import partial

import numpy as np
import test_util as tu
import tritonclient.grpc as grpcclient
# tritonclient.grpc.aio is only used by the disabled AsyncIO tests below
# (see DLIS-5476); CodeQL flags it as an unused import in the meantime.
import tritonclient.grpc.aio as aiogrpcclient
from tritonclient.utils import InferenceServerException


class UserData:
def __init__(self):
self._completed_requests = queue.Queue()


def callback(user_data, result, error):
if error:
user_data._completed_requests.put(error)
else:
user_data._completed_requests.put(result)


class ClientCancellationTest(tu.TestResultCollector):
def setUp(self):
self.model_name_ = "custom_identity_int32"
self.input0_data_ = np.array([[10]], dtype=np.int32)
self._start_time_ms = 0
self._end_time_ms = 0

def _record_start_time_ms(self):
self._start_time_ms = int(round(time.time() * 1000))

def _record_end_time_ms(self):
self._end_time_ms = int(round(time.time() * 1000))

def _test_runtime_duration(self, upper_limit):
        self.assertTrue(
            (self._end_time_ms - self._start_time_ms) < upper_limit,
            f"expected test runtime under {upper_limit} ms, got "
            f"{self._end_time_ms - self._start_time_ms} ms",
        )

def _prepare_request(self):
self.inputs_ = []
self.inputs_.append(grpcclient.InferInput("INPUT0", [1, 1], "INT32"))
self.outputs_ = []
self.outputs_.append(grpcclient.InferRequestedOutput("OUTPUT0"))

self.inputs_[0].set_data_from_numpy(self.input0_data_)

def test_grpc_async_infer(self):
# Sends a request using async_infer to a
# model that takes 10s to execute. Issues
# a cancellation request after 2s. The client
# should return with appropriate exception within
# 5s.
triton_client = grpcclient.InferenceServerClient(
url="localhost:8001", verbose=True
)
self._prepare_request()

user_data = UserData()

self._record_start_time_ms()

with self.assertRaises(InferenceServerException) as cm:
future = triton_client.async_infer(
model_name=self.model_name_,
inputs=self.inputs_,
callback=partial(callback, user_data),
outputs=self.outputs_,
)
time.sleep(2)
future.cancel()

data_item = user_data._completed_requests.get()
if type(data_item) == InferenceServerException:
raise data_item
self.assertIn("Locally cancelled by application!", str(cm.exception))

self._record_end_time_ms()
self._test_runtime_duration(5000)

def test_grpc_stream_infer(self):
# Sends a request using async_stream_infer to a
# model that takes 10s to execute. Issues stream
# closure with cancel_requests=True. The client
# should return with appropriate exception within
# 5s.
triton_client = grpcclient.InferenceServerClient(
url="localhost:8001", verbose=True
)

self._prepare_request()
user_data = UserData()

triton_client.start_stream(callback=partial(callback, user_data))
self._record_start_time_ms()

with self.assertRaises(InferenceServerException) as cm:
for i in range(1):
triton_client.async_stream_infer(
model_name=self.model_name_,
inputs=self.inputs_,
outputs=self.outputs_,
)
time.sleep(2)
triton_client.stop_stream(cancel_requests=True)
data_item = user_data._completed_requests.get()
if type(data_item) == InferenceServerException:
raise data_item
self.assertIn("Locally cancelled by application!", str(cm.exception))

self._record_end_time_ms()
self._test_runtime_duration(5000)


# AsyncIO cancellation tests are disabled. Enable them once
# DLIS-5476 is implemented.
# def test_aio_grpc_async_infer(self):
# # Sends a request using infer of grpc.aio to a
# # model that takes 10s to execute. Issues
# # a cancellation request after 2s. The client
# # should return with appropriate exception within
# # 5s.
# async def cancel_request(call):
# await asyncio.sleep(2)
# self.assertTrue(call.cancel())
#
# async def handle_response(generator):
# with self.assertRaises(asyncio.exceptions.CancelledError) as cm:
# _ = await anext(generator)
#
# async def test_aio_infer(self):
# triton_client = aiogrpcclient.InferenceServerClient(
# url="localhost:8001", verbose=True
# )
# self._prepare_request()
# self._record_start_time_ms()
#
# generator = triton_client.infer(
# model_name=self.model_name_,
# inputs=self.inputs_,
# outputs=self.outputs_,
# get_call_obj=True,
# )
# grpc_call = await anext(generator)
#
# tasks = []
# tasks.append(asyncio.create_task(handle_response(generator)))
# tasks.append(asyncio.create_task(cancel_request(grpc_call)))
#
# for task in tasks:
# await task
#
# self._record_end_time_ms()
# self._test_runtime_duration(5000)
#
# asyncio.run(test_aio_infer(self))
#
# def test_aio_grpc_stream_infer(self):
# # Sends a request using stream_infer of grpc.aio
# # library model that takes 10s to execute. Issues
# # stream closure with cancel_requests=True. The client
# # should return with appropriate exception within
# # 5s.
# async def test_aio_streaming_infer(self):
# async with aiogrpcclient.InferenceServerClient(
# url="localhost:8001", verbose=True
# ) as triton_client:
#
# async def async_request_iterator():
# for i in range(1):
# await asyncio.sleep(1)
# yield {
# "model_name": self.model_name_,
# "inputs": self.inputs_,
# "outputs": self.outputs_,
# }
#
# self._prepare_request()
# self._record_start_time_ms()
# response_iterator = triton_client.stream_infer(
# inputs_iterator=async_request_iterator(), get_call_obj=True
# )
# streaming_call = await anext(response_iterator)
#
# async def cancel_streaming(streaming_call):
# await asyncio.sleep(2)
# streaming_call.cancel()
#
# async def handle_response(response_iterator):
# with self.assertRaises(asyncio.exceptions.CancelledError) as cm:
# async for response in response_iterator:
# self.assertTrue(False, "Received an unexpected response!")
#
# tasks = []
# tasks.append(asyncio.create_task(handle_response(response_iterator)))
# tasks.append(asyncio.create_task(cancel_streaming(streaming_call)))
#
# for task in tasks:
# await task
#
# self._record_end_time_ms()
# self._test_runtime_duration(5000)
#
# asyncio.run(test_aio_streaming_infer(self))


if __name__ == "__main__":
unittest.main()
config.pbtxt for the custom_identity_int32 test model (new file)
@@ -0,0 +1,54 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

name: "custom_identity_int32"
backend: "identity"
max_batch_size: 1024
version_policy: { latest { num_versions: 1 }}
instance_group [ { kind: KIND_CPU } ]

input [
{
name: "INPUT0"
data_type: TYPE_INT32
dims: [ -1 ]
}
]
output [
{
name: "OUTPUT0"
data_type: TYPE_INT32
dims: [ -1 ]
}
]
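
# execute_delay_ms makes the identity backend hold each request for the
# configured time (10 s here), giving the cancellation tests time to cancel
# in-flight requests before they complete.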

parameters [
{
key: "execute_delay_ms"
value: { string_value: "10000" }
}
]
50 changes: 50 additions & 0 deletions qa/L0_request_cancellation/test.sh
@@ -42,6 +42,23 @@ DATADIR=${DATADIR:="/data/inferenceserver/${REPO_VERSION}"}
RET=0

mkdir -p models/model/1
mkdir -p $DATADIR/custom_identity_int32/1

export CUDA_VISIBLE_DEVICES=0

RET=0

CLIENT_CANCELLATION_TEST=client_cancellation_test.py
TEST_RESULT_FILE='test_results.txt'

rm -f *.log
rm -f *.log.*

CLIENT_LOG=`pwd`/client.log
DATADIR=`pwd`/models
SERVER=/opt/tritonserver/bin/tritonserver
SERVER_ARGS="--model-repository=$DATADIR --log-verbose=1"
source ../common/util.sh

SERVER_LOG=server.log
LD_LIBRARY_PATH=/opt/tritonserver/lib:$LD_LIBRARY_PATH ./request_cancellation_test > $SERVER_LOG
Reviewer comment (Member):

Needs rebasing. It looks like test.sh is created from scratch in here.
@@ -50,6 +67,39 @@ if [ $? -ne 0 ]; then
RET=1
fi

# gRPC client-side cancellation tests...
for i in test_grpc_async_infer \
test_grpc_stream_infer \
; do

SERVER_LOG=${i}.server.log
run_server
if [ "$SERVER_PID" == "0" ]; then
echo -e "\n***\n*** Failed to start $SERVER\n***"
cat $SERVER_LOG
exit 1
fi

set +e
python $CLIENT_CANCELLATION_TEST ClientCancellationTest.$i >>$CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
echo -e "\n***\n*** Test $i Failed\n***" >>$CLIENT_LOG
echo -e "\n***\n*** Test $i Failed\n***"
RET=1
else
check_test_results $TEST_RESULT_FILE 1
if [ $? -ne 0 ]; then
cat $CLIENT_LOG
echo -e "\n***\n*** Test Result Verification Failed\n***"
RET=1
fi
fi

set -e
kill $SERVER_PID
wait $SERVER_PID
done

if [ $RET -eq 0 ]; then
echo -e "\n***\n*** Test Passed\n***"
else
    echo -e "\n***\n*** Test FAILED\n***"
fi

exit $RET