Skip to content

Commit

Permalink
Merge pull request #21 from alecgunny/triton-stateful-backend
Browse files Browse the repository at this point in the history
Triton stateful backend
  • Loading branch information
alecgunny authored Sep 30, 2022
2 parents 94064d0 + 2ecfd1e commit 716352a
Show file tree
Hide file tree
Showing 19 changed files with 1,530 additions and 843 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/unit-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ jobs:
aeriel: 'hermes/hermes.aeriel/**'
cloudbreak: 'hermes/hermes.cloudbreak/**'
quiver: 'hermes/hermes.quiver/**'
stillwater: 'hermes/hermes.stillwater/**'
# stillwater: 'hermes/hermes.stillwater/**'

test:
runs-on: ubuntu-latest
Expand All @@ -31,7 +32,7 @@ jobs:
fail-fast: false
matrix:
python-version: ['3.8', '3.9', '3.10']
poetry-version: [1.2.0a2, 1.2.0b2]
poetry-version: ['1.2.0a2', '1.2.0b2']
library: ${{ fromJSON(needs.changes.outputs.libraries) }}
steps:
- uses: actions/checkout@v2
Expand Down
157 changes: 99 additions & 58 deletions hermes/hermes.aeriel/hermes/aeriel/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import sys
import time
from collections import defaultdict
from copy import deepcopy
from queue import Empty, Queue
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -169,6 +170,7 @@ def __init__(
else:
self.clock = None
self.callback_q = Queue()
self._sequences = {}

def _build_inputs(
self, batch_size: int
Expand Down Expand Up @@ -204,31 +206,27 @@ def _build_inputs(
datatype=metadata_input.datatype,
)
for step in config.ensemble_scheduling.step:
# see if this input corresponds to the input
# for a snapshotter model. TODO: come up
# with a better way of identifying snapshotter
# models than by using the name
if (
len(step.input_map) == 1
and list(step.input_map.values())[0] == config_input.name
and step.model_name.startswith("snapshotter")
):
# see if this input corresponds to
# the input for a snapshotter model
input_map = list(step.input_map.values())
if len(input_map) == 1 and input_map[0] == config_input.name:
model_config = self.client.get_model_config(
step.model_name
)
model_config = model_config.config
if len(model_config.sequence_batching.state) == 0:
continue

# only support streaming with batch size 1
shape[0] = 1

# now read the model config for the snapshotter to
# figure out what the names of its outputs are
snapshotter_config = self.client.get_model_config(
step.model_name
).config

# iterate through the outputs of the snapshotter
# and figure out how many channels each of its
# states need to have. Record them in a dict
# mapping from the name of the state to the
# number of channels in the state
channel_map = {}
for x in snapshotter_config.output:
for x in model_config.output:
map_key = step.output_map[x.name]

# look for the model whose input
Expand Down Expand Up @@ -343,6 +341,71 @@ def __exit__(self, *exc_args):
self.client._stream._request_queue.put(None)
self.client.__exit__(*exc_args)

def _validate_inputs(
    self,
    x: Union[np.ndarray, Dict[str, np.ndarray]],
    sequence_id: Optional[int] = None,
):
    """
    Normalize an inference input and select the input objects
    that will be filled with its data.

    Args:
        x:
            Either a single array or a dict mapping input/state
            names to arrays. A single array is only accepted when
            the model has exactly one input or state, in which case
            it is wrapped into a one-entry dict keyed by that name.
        sequence_id:
            Identifier of the sequence this request belongs to.
            Required when the client has states, and rejected
            when it has none.
    Returns:
        A tuple `(x, inputs, states, sequence_id)` where `x` is
        the dict form of the input, and `inputs` and `states` are
        the objects to fill for this request. For sequences these
        are per-sequence deep copies of `self.inputs` and
        `self.states`, cached in `self._sequences`, so that
        separate sequences do not write into the same objects.
    Raises:
        ValueError:
            If a single array is passed to a model with more than
            one input/state, if `sequence_id` is missing for a
            model with states, or if `sequence_id` is given for a
            model without states.
    """
    # if we just passed a single array, make sure we only
    # have one input or state that we need to pass it to
    if not isinstance(x, dict):
        if len(self.inputs) + self.num_states > 1:
            # report the same state count the check above used so
            # the message can't contradict the reason for raising
            raise ValueError(
                "Only passed a single input array, but "
                "model {} has {} inputs and {} states".format(
                    self.model_name, len(self.inputs), self.num_states
                )
            )
        elif len(self.inputs) > 0:
            # only a non-stateful input: key the array by its name
            x = {self.inputs[0].name(): x}
        else:
            # only a stateful input: key the array by the name of
            # the first entry in the first state's channel map
            state_name = list(self.states[0][1].keys())[0]
            x = {state_name: x}

    if sequence_id is None:
        if len(self.states) > 0:
            # enforce that we provide a sequence id if
            # there are any stateful inputs. TODO: should
            # we do the same if there are any stateful outputs?
            # Or just states somewhere in the model in general?
            raise ValueError(
                "Must provide sequence id for model with states {}".format(
                    [i[0].name() for i in self.states]
                )
            )
        # no states to fill and no sequence to track, so the shared
        # inputs are used directly. This is a plain fall-through
        # rather than a further condition so that `inputs` and
        # `states` are always bound by the time we return
        return x, self.inputs, [], sequence_id

    if self.num_states == 0:
        raise ValueError(
            "Specified sequence id {} for request to "
            "non-stateful model {}".format(sequence_id, self.model_name)
        )

    try:
        # existing sequence: reuse the objects created for it
        inputs, states = self._sequences[sequence_id]
    except KeyError:
        # new sequence: give it its own copies of the inputs and
        # states so it doesn't share them with other sequences
        logging.debug(
            f"Creating new inputs and states for sequence {sequence_id}"
        )
        inputs, states = deepcopy(self.inputs), deepcopy(self.states)
        self._sequences[sequence_id] = (inputs, states)
    return x, inputs, states, sequence_id

def infer(
self,
x: Union[np.ndarray, Dict[str, np.ndarray]],
Expand Down Expand Up @@ -402,44 +465,11 @@ def infer(
no stateful inputs.
"""

# if we just passed a single array, make sure we only
# have one input or state that we need to pass it to
if not isinstance(x, dict):
if len(self.inputs) + self.num_states > 1:
raise ValueError(
"Only passed a single input array, but "
"model {} has {} inputs and {} states".format(
self.model_name, len(self.inputs), len(self.states)
)
)
elif len(self.inputs) > 0:
# if we only have a non-stateful input, set `x`
# up to keep the parsing methods below more standard
x = {self.inputs[0].name(): x}
else:
# same for if we only have a stateful input
state_name = list(self.states[0][1].keys())[0]
x = {state_name: x}

if request_id is None:
# if a request_id wasn't specified, give it a random
# one that will (probably) be unique
request_id = str(random.randint(0, 1e16))

if sequence_id is None and len(self.states) > 0:
# enforce that we provide a sequence id if
# there are any stateful inputs. TODO: should
# we do the same if there are any stateful outputs?
# Or just states somewhere in the model in general?
raise ValueError(
"Must provide sequence id for model with states {}".format(
[i[0].name() for i in self.states]
)
)
x, inputs, states, sequence_id = self._validate_inputs(x, sequence_id)

# if we have any non-stateful inputs, set their
# input value using the corresponding package
for input in self.inputs:
for input in inputs:
name = input.name()
try:
value = x[name]
Expand All @@ -456,8 +486,8 @@ def infer(
# the updates for each state and set the input
# message value using those updates. Do checks
# on the sequence start and end values
for state, channel_map in self.states:
states = []
for state, channel_map in states:
state_values = []

# for each update in the state, try to
# get the update for it and do the checks
Expand All @@ -468,15 +498,20 @@ def infer(
raise ValueError(f"Missing state {name}")

# add the update to our running list of updates
states.append(value[None])
state_values.append(value[None])

# if we have more than one state, combine them
# into a single tensor along the channel axis
if len(states) > 1:
state = np.concatenate(states, axis=1)
if len(state_values) > 1:
state = np.concatenate(state_values, axis=1)
state.set_data_from_numpy(state)
else:
state.set_data_from_numpy(states[0])
state.set_data_from_numpy(state_values[0])

# if a request_id wasn't specified, give it a random
# one that will (probably) be unique
if request_id is None:
request_id = random.randint(0, 1e16)

# keep track of in-flight times if we're profiling
if self.clock is not None:
Expand All @@ -487,6 +522,8 @@ def infer(
# to the request id so we can parse in the callback
if sequence_id is not None:
request_id = f"{request_id}_{sequence_id}"
else:
request_id = str(request_id)

if len(self.states) > 0:
# make a streaming inference if we have input states
Expand All @@ -495,13 +532,17 @@ def infer(
self.client.async_stream_infer(
self.model_name,
model_version=str(self.model_version),
inputs=self.inputs + [x[0] for x in self.states],
inputs=inputs + [x[0] for x in states],
request_id=request_id,
sequence_id=sequence_id,
sequence_start=sequence_start,
sequence_end=sequence_end,
timeout=60,
)

if sequence_end:
# remove the inputs for this sequence if it's complete
self._sequences.pop(sequence_id)
else:
self.client.async_infer(
self.model_name,
Expand Down
2 changes: 1 addition & 1 deletion hermes/hermes.aeriel/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion hermes/hermes.aeriel/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ packages = [

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
tritonclient = {extras = ["all"], version = "^2.18"}
tritonclient = {extras = ["all"], version = "^2.22"}

numpy = {version = "^1.22", optional = true}
spython = {version = "^0.1", optional = true}

Expand Down
18 changes: 11 additions & 7 deletions hermes/hermes.aeriel/tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,17 @@ def test_inference_client(mock1, mock2, num_inputs, num_states, version):
if (num_inputs + num_states) == 1:
if num_states == 1:
x = np.random.randn(*states[0][0].shape)
kwargs = {
"sequence_id": 1001,
"sequence_start": MagicMock(),
}
else:
x = np.random.randn(1, num_channels, dim)
kwargs = {}

dummy = MagicMock()
client.infer(
x, request_id=10, sequence_id=1001, sequence_start=dummy
)
postprocessor.assert_called_with(
x.reshape(-1)[-1] + 1, 10, 1001
)
client.infer(x, request_id=10, **kwargs)
x_expected = x.reshape(-1)[-1] + 1
if num_states > 0:
postprocessor.assert_called_with(x_expected, 10, 1001)
else:
postprocessor.assert_called_with(x_expected, 10, None)
25 changes: 19 additions & 6 deletions hermes/hermes.quiver/hermes/quiver/exporters/tensorrt/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,17 @@ def _convert_network(
# batch size, the config's value will read 0,
# so replace with 1 as a default here for streaming
builder.max_batch_size = max(model_config.max_batch_size, 1)
if use_fp16:
builder.fp16_mode = True
builder.strict_type_constraints = True

# if any of the inputs have a variable
# length batch dimension, create an
# optimization profile for that input with
# the most optimized batch size being the largest
config = stack.enter_context(builder.create_builder_config())
config.max_workspace_size = 1 << 28
if use_fp16:
config.flags |= 1 << int(trt.BuilderFlag.FP16)
# builder.strict_type_constraints = True

for input in model_config.input:
if input.dims[0] != -1:
# this input doesn't have a variable
Expand Down Expand Up @@ -78,7 +79,17 @@ def _convert_network(
)
)
parser = stack.enter_context(trt.OnnxParser(network, logger))
parser.parse(model_binary)
success = parser.parse(model_binary)
if not success:
errors = [parser.get_error(i) for i in range(parser.num_errors)]
errors = "\n".join(map(str, errors))
completed_layers = [
network.get_layer(i).name for i in range(network.num_layers)
]
msg = "Parsing ONNX binary failed. Completed layers:\n"
msg += "\n".join(completed_layers)
msg += "\n\nRaised errors:\n" + errors
raise RuntimeError(msg)

if len(model_config.output) == 1 and network.num_outputs == 0:
# if we only have a single output and for whatever
Expand Down Expand Up @@ -112,11 +123,13 @@ def _convert_network(
if len(network_output.shape) != len(output.dims):
raise ValueError(
"Number of dimensions {} specified for "
"output {} not equal to number {} found "
"in TensorRT network".format(
"output {} with shape {} not equal to number {} found "
"in TensorRT network with shape {}".format(
len(output.dims),
output.name,
output.dims,
len(network_output.shape),
network_output.shape,
)
)

Expand Down
Loading

0 comments on commit 716352a

Please sign in to comment.