Support continuous batching in sequence batch streaming case #3160

Merged
merged 34 commits into from Jun 3, 2024

Changes from 13 commits (34 commits total)
ea83cd9
Support continuous batching in sequence batch streaming case
lxning May 24, 2024
eab4d24
add test stateful sequence continuous batching
lxning May 29, 2024
d4a777d
fmt
lxning May 29, 2024
55068cc
fix init atomicboolean
lxning May 29, 2024
8f6d366
update test and example
lxning May 30, 2024
7e7b339
fix open session test
lxning May 30, 2024
59cc12f
fix open session test
lxning May 30, 2024
287333f
set sequence id
lxning May 30, 2024
1954417
set seq id in response
lxning May 30, 2024
e70ffb4
update test
lxning May 30, 2024
a41ecf5
fix wrong expected result
lxning May 30, 2024
5346f26
fixed test expectation
lxning May 30, 2024
81525c5
fmt
lxning May 30, 2024
b54f70b
update test path
lxning May 30, 2024
8428cea
simplify
lxning May 30, 2024
82f97f6
update for comments
lxning May 30, 2024
abba9df
remove sequence continuous parameter
lxning May 31, 2024
ab8fd67
update cancel
lxning May 31, 2024
c0ebab3
update cancel
lxning May 31, 2024
2830e32
update cleanup
lxning May 31, 2024
c3122f7
support mix mode stream and non-stream
lxning Jun 1, 2024
0167921
clean code
lxning Jun 1, 2024
a4e80ae
update test
lxning Jun 1, 2024
cc8216a
fix order
lxning Jun 1, 2024
b4934ec
update log
lxning Jun 1, 2024
b156aea
update headers
lxning Jun 1, 2024
74d5321
test mix mode
lxning Jun 2, 2024
0ad8645
update fmt
lxning Jun 2, 2024
fe18dea
increase counter
lxning Jun 2, 2024
4b43233
increase counter
lxning Jun 2, 2024
77b838b
add comments
lxning Jun 2, 2024
182e831
update readme
lxning Jun 3, 2024
5f87195
update readme
lxning Jun 3, 2024
1ce8bf0
Added stop torchserve to unit tests
mreso Jun 3, 2024
@@ -92,16 +92,10 @@ handler:
### Step 3: Generate mar or tgz file

```bash
torch-model-archiver --model-name stateful --version 1.0 --model-file model.py --serialized-file model_cnn.pt --handler stateful_handler.py -r requirements.txt --config-file model-config.yaml
torch-model-archiver --model-name stateful --version 1.0 --model-file model.py --serialized-file model_cnn.pt --handler stateful_handler.py -r ../requirements.txt --config-file model-config.yaml
```

### Step 4: Start torchserve

```bash
torchserve --start --ncs --model-store model_store --models stateful.mar
```

### Step 6: Build GRPC Client
### Step 4: Build GRPC Client
Details can be found [here](https://github.com/pytorch/serve/blob/master/docs/grpc_api.md).
* Install gRPC python dependencies
```bash
@@ -111,26 +105,23 @@ pip install -U grpcio protobuf grpcio-tools googleapis-common-protos

* Generate python gRPC client stub using the proto files
```bash
cd ../..
cd ../../..
python -m grpc_tools.protoc -I third_party/google/rpc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto
cd -
```

### Step 7: Run inference
### Step 5: Run inference
* Start TorchServe

```bash
torchserve --ncs --start --model-store models --model stateful.mar --ts-config config.properties
torchserve --ncs --start --model-store models --model stateful.mar --ts-config examples/stateful/config.properties
```

* Run sequence inference via GRPC client
```bash
cd ../../
python ts_scripts/torchserve_grpc_client.py infer_stream2 stateful seq_0 examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt
```

* Run sequence inference via HTTP
```bash
cd ../../
curl -H "ts_request_sequence_id: seq_0" http://localhost:8080/predictions/stateful -T examples/stateful/sample/sample1.txt
```
127 changes: 127 additions & 0 deletions examples/stateful/sequence_continuous_batching/Readme.md
@@ -0,0 +1,127 @@
# Stateful Inference

A stateful model can exploit dependencies between successive inference requests. It maintains a persistent state across requests, so that the outcome of an earlier request feeds into the ones that follow. Typical examples are online speech recognition systems built on models such as the Long Short-Term Memory (LSTM) network. Stateful inference therefore requires the model server to preserve the order of inference requests, so that each prediction builds on the previous results.

For this purpose, TorchServe provides a mechanism called sequence batching. It takes one inference request at a time from each sequence and combines requests from different sequences into a single batch. Every request carries a unique sequence ID, which can be retrieved with the `get_sequence_id` function of context.py. Custom handlers use this `sequence ID` as the key to store and fetch values in the backend cache store, which keeps stateful inference efficient. A client can also reuse the `sequence ID` when a connection resumes, as long as the sequence has not expired on the TorchServe side.
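
To make the pattern concrete, here is a minimal, self-contained sketch (a hypothetical helper, not part of this example's code) of state keyed by sequence ID; the real handler in Step 1 below uses `self.context.get_sequence_id(idx)` and an LRU cache instead of a plain dict:

```python
# Hypothetical illustration only: state is kept per sequence ID, so each new
# request builds on the result left by the previous request of the same sequence.
def accumulate(cache: dict, sequence_id: str, request_value: int) -> int:
    prev = cache.get(sequence_id, 0)        # state stored by the previous request
    cache[sequence_id] = prev + request_value
    return cache[sequence_id]

cache = {}
print(accumulate(cache, "seq_0", 1))  # 1
print(accumulate(cache, "seq_0", 2))  # 3  -- builds on the earlier result
print(accumulate(cache, "seq_1", 5))  # 5  -- a different sequence has independent state
```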

The following picture shows the workflow of stateful inference. Each job group has a job queue that stores the incoming inference requests of one stream. The maximum capacity of a job queue is defined by `maxSequenceJobQueueSize`. A sequence batch aggregator polls one inference request from each job group and sends the resulting batch of requests to the backend.

![sequence batch](../../docs/images/stateful_batch.jpg)

This example is a practical showcase of stateful inference. Under the hood, the backend uses an [LRU dictionary](https://github.com/amitdev/lru-dict) as its caching layer. Users can choose a different caching library in their handler implementation based on their own use case.
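
As a quick, hedged illustration of the [lru-dict](https://github.com/amitdev/lru-dict) package used here (the capacity and keys below are made up for demonstration; the handler sizes the cache from `model-config.yaml`):

```python
from lru import LRU

cache = LRU(2)                 # keeps state for at most 2 sequences
cache["seq_0"] = 3
cache["seq_1"] = 7
cache["seq_2"] = 1             # exceeds capacity: the least recently used entry ("seq_0") is evicted
print(cache.has_key("seq_0"))  # False -- evicted
print(cache.has_key("seq_2"))  # True
```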

### Step 1: Implement handler

stateful_handler.py is an example of a stateful handler. It creates a cache `self.cache` by calling [LRU](https://github.com/amitdev/lru-dict).

```python
    def initialize(self, ctx: Context):
        """
        Loads the model and Initializes the necessary artifacts
        """

        super().initialize(ctx)
        if self.context.model_yaml_config["handler"] is not None:
            try:
                self.cache = LRU(
                    int(self.context.model_yaml_config["handler"]["cache"]["capacity"]))
            except KeyError:
                logger.warn("No cache capacity was set! Using default value.")
                self.cache = LRU(StatefulHandler.DEFAULT_CAPACITY)

        self.initialized = True
```

The handler uses the sequence ID (i.e., `sequence_id = self.context.get_sequence_id(idx)`) as the key to store and fetch values from `self.cache`.

```python
    def preprocess(self, data):
        """
        Preprocess function to convert the request input to a tensor(Torchserve supported format).
        The user needs to override to customize the pre-processing

        Args :
            data (list): List of the data from the request input.

        Returns:
            tensor: Returns the tensor data of the input
        """

        self.sequence_ids = {}
        results = []
        for idx, row in enumerate(data):
            sequence_id = self.context.get_sequence_id(idx)

            prev = int(0)
            if self.cache.has_key(sequence_id):
                prev = int(self.cache[sequence_id])

            request = row.get("data") or row.get("body")
            if isinstance(request, (bytes, bytearray)):
                request = request.decode("utf-8")

            val = prev + int(request)
            self.cache[sequence_id] = val
            results.append(val)

        return results
```

### Step 2: Model configuration

Stateful inference has two parameters. TorchServe can process (maxWorkers * batchSize) sequences of inference requests of a model in parallel; with the configuration below (maxWorkers: 2, batchSize: 4), up to 8 sequences are served concurrently.
* sequenceMaxIdleMSec: the maximum idle time in milliseconds of a sequence of inference requests for this stateful model. The default value is 0 (i.e., the model is not stateful). TorchServe stops processing new inference requests of a sequence once its idle time exceeds this timeout.
* maxSequenceJobQueueSize: the job queue size of one inference sequence of this stateful model. The default value is 1.


```yaml
#cat model-config.yaml

minWorkers: 2
maxWorkers: 2
batchSize: 4
sequenceMaxIdleMSec: 60000
maxSequenceJobQueueSize: 10
sequenceBatching: true

handler:
  cache:
    capacity: 4
```

### Step 3: Generate mar or tgz file

```bash
torch-model-archiver --model-name stateful --version 1.0 --model-file model.py --serialized-file model_cnn.pt --handler stateful_handler.py -r ../requirements.txt --config-file model-config.yaml
```

### Step 4: Build GRPC Client
Details can be found [here](https://github.com/pytorch/serve/blob/master/docs/grpc_api.md).
* Install gRPC python dependencies
```bash
git submodule init
pip install -U grpcio protobuf grpcio-tools googleapis-common-protos
```

* Generate python gRPC client stub using the proto files
```bash
cd ../../..
python -m grpc_tools.protoc -I third_party/google/rpc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto
```

### Step 5: Run inference
* Start TorchServe

```bash
torchserve --ncs --start --model-store models --model stateful.mar --ts-config examples/stateful/config.properties
```

* Run sequence inference via GRPC client
```bash
python ts_scripts/torchserve_grpc_client.py infer_stream2 stateful seq_0 examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt
```

* Run sequence inference via HTTP
```bash
curl -H "ts_request_sequence_id: seq_0" http://localhost:8080/predictions/stateful -T examples/stateful/sample/sample1.txt
```
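
For reference, the same HTTP flow can be driven from Python. The sketch below is illustrative only: it mirrors the curl call above and the request-value convention in the handler (positive integers are accumulated, `0` ends the sequence, `-1` cancels its in-flight requests), and it assumes TorchServe marks the first request of a new sequence as the sequence start.

```python
import requests

url = "http://localhost:8080/predictions/stateful"
headers = {"ts_request_sequence_id": "seq_0"}  # same header as the curl example above

print(requests.post(url, data=b"1", headers=headers).text)  # accumulated into the sequence state
print(requests.post(url, data=b"2", headers=headers).text)  # builds on the previous request
requests.post(url, data=b"0", headers=headers)              # "0" asks the handler to end the sequence
# requests.post(url, data=b"-1", headers=headers)           # "-1" would cancel the sequence's open requests
```
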
11 changes: 11 additions & 0 deletions examples/stateful/sequence_continuous_batching/model-config.yaml
@@ -0,0 +1,11 @@
minWorkers: 2
maxWorkers: 2
batchSize: 4
maxNumSequence: 4
sequenceMaxIdleMSec: 10
maxSequenceJobQueueSize: 10
sequenceContinuousBatching: true

handler:
  cache:
    capacity: 4
157 changes: 157 additions & 0 deletions examples/stateful/sequence_continuous_batching/stateful_handler.py
@@ -0,0 +1,157 @@
import logging
from abc import ABC

from lru import LRU

from ts.context import Context
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class StatefulHandler(BaseHandler, ABC):
    DEFAULT_CAPACITY = 10

    def __init__(self):
        super().__init__()
        self.cache: LRU = None

    def initialize(self, ctx: Context):
        """
        Loads the model and Initializes the necessary artifacts
        """

        ctx.cache = {}
        if ctx.model_yaml_config["handler"] is not None:
            try:
                self.cache = LRU(
                    int(ctx.model_yaml_config["handler"]["cache"]["capacity"])
                )
            except KeyError:
                logger.error("No cache capacity was set! Using default value.")
                self.cache = LRU(StatefulHandler.DEFAULT_CAPACITY)

        self.initialized = True

    def preprocess(self, data):
        """
        Preprocess function to convert the request input to a tensor(Torchserve supported format).
        The user needs to override to customize the pre-processing

        Args :
            data (list): List of the data from the request input.

        Returns:
            tensor: Returns the tensor data of the input
        """

        results = []

        for idx, row in enumerate(data):
            sequence_id = self.context.get_sequence_id(idx)
            self.context.set_response_header(
                idx, self.context.header_key_sequence_id, sequence_id
            )
            req_id = self.context.get_request_id(idx)

            if self.context.get_request_header(
                idx, self.context.header_key_sequence_start
            ):
                prev = int(0)
            elif self.cache.has_key(sequence_id):
                prev = int(self.cache[sequence_id])
            else:
                prev = None
                logger.error(
                    f"Not received sequence_start request for sequence_id:{sequence_id} before"
                )

            request = row.get("data") or row.get("body")
            if isinstance(request, (bytes, bytearray)):
                request = request.decode("utf-8")

            if not self.context.cache.get(sequence_id, {}).get(req_id, {}):
                self.context.cache[sequence_id] = {
                    req_id: {
                        "stopping_criteria": self._create_stopping_criteria(
                            req_id=req_id, seq_id=sequence_id, cache=self.context.cache
                        )
                    },
                }

            # -1: cancel
            if int(request) == -1:
                for r_id in self.context.cache[sequence_id].keys():
                    self.context.cache[sequence_id][r_id]["cancel"] = True
                results.append(int(request))
            elif prev is None:
                logger.info(
                    f"Close the sequence:{sequence_id} without open session request"
                )
                self.context.cache[sequence_id][req_id]["end"] = True
                self.context.set_response_header(
                    idx, self.context.header_key_sequence_end, sequence_id
                )
                results.append(int(request))
            else:
                val = prev + int(request)
                self.cache[sequence_id] = val
                # 0: end
                if int(request) == 0:
                    self.context.cache[sequence_id][req_id]["end"] = True
                    self.context.set_response_header(
                        idx, self.context.header_key_sequence_end, sequence_id
                    )

                results.append(val)

        return results

    def inference(self, data, *args, **kwargs):
        return data

    def postprocess(self, data):
        """
        The post process function makes use of the output from the inference and converts into a
        Torchserve supported response output.

        Returns:
            List: The post process function returns a list of the predicted output.
        """

        return data

    def clean_up_seq(self, seq_id):
        # clean up
        del self.cache[seq_id]
        del self.context.cache[seq_id]

    def clean_up_req(self, seq_id, req_id):
        # clean up
        if seq_id in self.context.cache:
            del self.context.cache[seq_id][req_id]

    def _create_stopping_criteria(self, req_id, seq_id, cache):
        class StoppingCriteria(object):
            def __init__(self, outer, req_id, seq_id, cache):
                self.req_id = req_id
                self.seq_id = seq_id
                self.cache = cache
                self.outer = outer
                self.counter = 5

            def __call__(self, res):
                # sequence end
                if self.cache[seq_id][req_id].get("end"):
                    self.outer.clean_up_seq(self.seq_id)
                    return True
                # cancel
                elif self.cache[seq_id][req_id].get("cancel") or self.counter == 0:
                    self.outer.clean_up_req(self.seq_id, self.req_id)
                    return True
                else:
                    self.counter -= 1

                return False

        return StoppingCriteria(outer=self, req_id=req_id, seq_id=seq_id, cache=cache)
@@ -60,7 +60,7 @@ public class ModelConfig {
*/
private long sequenceMaxIdleMSec;
/**
* the job queue size of an inference sequence of this stateful model. The default value is 1.
* the job queue size of one inference sequence of this stateful model. The default value is 1.
*/
private int maxSequenceJobQueueSize = 1;
/** the max number of sequences can be accepted. The default value is 1. */
@@ -75,6 +75,11 @@ private boolean useVenv;
private boolean useVenv;
/** sequenceBatching is a flag to enable https://github.com/pytorch/serve/issues/2743 */
private boolean sequenceBatching;
/**
* sequenceContinuousBatching is a flag to enable continuous batching in sequenceBatching
* streaming use case so that a new inference request from the same sequence can be processed.
*/
private boolean sequenceContinuousBatching;

public static ModelConfig build(Map<String, Object> yamlMap) {
ModelConfig modelConfig = new ModelConfig();
@@ -222,6 +227,15 @@ public static ModelConfig build(Map<String, Object> yamlMap) {
"Invalid sequenceBatching: {}, should be true or false", v);
}
break;
case "sequenceContinuousBatching":
if (v instanceof Boolean) {
modelConfig.setSequenceContinuousBatching((boolean) v);
} else {
logger.warn(
"Invalid sequenceContinuousBatching: {}, should be true or false",
v);
}
break;
case "useVenv":
if (v instanceof Boolean) {
modelConfig.setUseVenv((boolean) v);
@@ -401,6 +415,15 @@ public void setSequenceBatching(boolean sequenceBatching) {
this.sequenceBatching = sequenceBatching;
}

public boolean isSequenceContinuousBatchingBatching() {
return sequenceContinuousBatching;
}

public void setSequenceContinuousBatching(boolean sequenceContinuousBatching) {
this.sequenceBatching = sequenceContinuousBatching;
this.sequenceContinuousBatching = sequenceContinuousBatching;
}

public int getMaxNumSequence() {
return maxNumSequence;
}