Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stateful inference #2513

Merged
merged 59 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 58 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
ed5239e
stateful inference-core layer
lxning Aug 1, 2023
0794f54
add grpc layer
lxning Aug 5, 2023
4d55643
add google rpc submodule
lxning Aug 7, 2023
a857307
fmt
lxning Aug 15, 2023
4ae7404
update sequence batch img
lxning Aug 15, 2023
5c0dd97
update sequence batch img
lxning Aug 15, 2023
0651806
fmt
lxning Aug 15, 2023
3e16993
delete used file
lxning Aug 15, 2023
6aee437
fmt
lxning Aug 15, 2023
91b9f99
fmt
lxning Aug 15, 2023
a3f84eb
fix log and update doc
lxning Aug 16, 2023
c60f390
update log
lxning Aug 16, 2023
f5c7707
fmt
lxning Aug 16, 2023
dd23216
merge master and fix conflict
lxning Sep 27, 2023
1ea33cf
make BatchAggregator as base
lxning Sep 27, 2023
fdb03c9
fix conflict
lxning Sep 27, 2023
c3a2cca
fix conflict
lxning Sep 27, 2023
ba1bc45
add SequenceBatchAggregator
lxning Sep 29, 2023
f6c888d
update ci for submodule
lxning Sep 29, 2023
d723754
merge master
lxning Oct 4, 2023
077bf27
refactor
lxning Oct 5, 2023
bd3296a
fmt
lxning Oct 5, 2023
d3777e5
merge master
lxning Oct 5, 2023
d1e5e8d
fmt
lxning Oct 5, 2023
c24f1b8
fix lint
lxning Oct 6, 2023
3831ff5
code refactor
lxning Oct 8, 2023
8840a1c
fix conflict merging master
lxning Oct 11, 2023
79f2e66
update readme
lxning Oct 11, 2023
91a201e
update readme
lxning Oct 11, 2023
a431e78
fmt
lxning Oct 11, 2023
7c81022
Merge branch 'master' into feat/stateful
lxning Oct 11, 2023
fef7780
fmt
lxning Oct 11, 2023
cb3d232
Merge branch 'master' into feat/stateful
lxning Oct 12, 2023
67436d3
test workflow
lxning Oct 12, 2023
fe663fb
revert test
lxning Oct 12, 2023
421f31e
revert test response
lxning Oct 12, 2023
7f7bb69
fmt
lxning Oct 25, 2023
838a896
fmt
lxning Oct 11, 2023
0698bab
fix conflict
lxning Oct 25, 2023
4749b74
update readme
lxning Oct 27, 2023
80053ca
allow number ofjobGroup is larger than batchsize
lxning Oct 28, 2023
44d3986
fmt
lxning Oct 11, 2023
8879393
Merge branch 'master' into feat/stateful
lxning Oct 28, 2023
b05e653
fix typo
lxning Oct 28, 2023
5f7125e
add stateful test data
lxning Oct 28, 2023
9bb9245
fmt
lxning Oct 28, 2023
2f83255
Merge branch 'master' into feat/stateful
lxning Oct 28, 2023
a592f10
fmt
lxning Oct 28, 2023
4b9145b
fmt
lxning Oct 30, 2023
5fe05cd
fmt
lxning Oct 11, 2023
8e7ce9e
Merge branch 'master' into feat/stateful
lxning Oct 30, 2023
0a90a87
set default maxNumSequence
lxning Nov 1, 2023
fb9cdb5
fmt
lxning Nov 3, 2023
627a31e
Merge branch 'master' into feat/stateful
lxning Nov 3, 2023
876d83d
fmt
lxning Oct 11, 2023
4b19885
Merge branch 'master' into feat/stateful
lxning Nov 3, 2023
c5a0708
revert back config.properties
lxning Nov 3, 2023
6dc374a
fmt
lxning Nov 5, 2023
d4ea03d
Merge branch 'master' into feat/stateful
lxning Nov 7, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/benchmark_nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ jobs:
java-version: '17'
- name: Checkout TorchServe
uses: actions/checkout@v3
with:
submodules: recursive
- name: Install dependencies
run: |
sudo apt-get update -y
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/ci_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ jobs:
java-version: '17'
- name: Checkout TorchServe
uses: actions/checkout@v3
with:
submodules: recursive
- name: Install dependencies
run: |
python ts_scripts/install_dependencies.py --environment=dev
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/ci_gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ jobs:
java-version: '17'
- name: Checkout TorchServe
uses: actions/checkout@v3
with:
submodules: recursive
- name: Install dependencies
run: |
python ts_scripts/install_dependencies.py --environment=dev --cuda=cu121
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/codeql.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v3
with:
submodules: recursive

- name: Setup Python 3.8
uses: actions/setup-python@v4
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/docker-ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ jobs:
python-version: ["3.8", "3.9", "3.10"]
steps:
- uses: actions/checkout@v3
with:
submodules: recursive

- name: Test build_image.sh script with custom tagging and gpu flag
working-directory: docker
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/docker-nightly-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ jobs:
architecture: x64
- name: Checkout TorchServe
uses: actions/checkout@v3
with:
submodules: recursive
- name: Login to Docker
env:
DOCKER_PASSWORD: ${{secrets.DOCKER_PASSWORD}}
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/regression_tests_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ jobs:
java-version: '17'
- name: Checkout TorchServe
uses: actions/checkout@v3
with:
submodules: recursive
- name: Install dependencies
run: |
python ts_scripts/install_dependencies.py --environment=dev
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/regression_tests_cpu_binaries.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ jobs:
binaries: ["pypi", "conda"]
steps:
- uses: actions/checkout@v3
with:
submodules: recursive
- name: Setup conda with Python ${{ matrix.python-version }}
uses: s-weigand/setup-conda@v1
with:
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/regression_tests_docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ jobs:
docker system prune --all --volumes -f
- name: Checkout TorchServe
uses: actions/checkout@v3
with:
submodules: recursive
- name: Branch name
run: |
echo $GITHUB_REF_NAME
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/regression_tests_gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ jobs:
java-version: '17'
- name: Checkout TorchServe
uses: actions/checkout@v3
with:
submodules: recursive
- name: Install dependencies
run: |
python ts_scripts/install_dependencies.py --environment=dev --cuda=cu121
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/regression_tests_gpu_binaries.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ jobs:
java-version: '17'
- name: Checkout TorchServe
uses: actions/checkout@v3
with:
submodules: recursive
- uses: conda-incubator/setup-miniconda@v2
with:
miniconda-version: "latest"
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/torchserve-nightly-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ jobs:
- run: conda install -y conda-build anaconda-client
- name: Checkout TorchServe
uses: actions/checkout@v3
with:
submodules: recursive
- name: Install dependencies
run: |
python ts_scripts/install_dependencies.py --environment=dev
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "third_party/google/rpc"]
path = third_party/google/rpc
url = https://github.com/googleapis/googleapis.git
7 changes: 4 additions & 3 deletions docs/grpc_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@ Run following commands to Register, run inference and unregister, densenet161 mo
```bash
git clone https://github.com/pytorch/serve
cd serve
git submodule init
```

- Install gRPC python dependencies

```bash
pip install -U grpcio protobuf grpcio-tools
pip install -U grpcio protobuf grpcio-tools googleapis-common-protos
```

- Start torchServe
Expand All @@ -51,7 +52,7 @@ torchserve --start --model-store models/
- Generate python gRPC client stub using the proto files

```bash
python -m grpc_tools.protoc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto
python -m grpc_tools.protoc -I third_party/google/rpc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto
```

- Register densenet161 model
Expand Down Expand Up @@ -95,4 +96,4 @@ def handle(data, context):
for i in range (3):
send_intermediate_predict_response(["intermediate_response"], context.request_ids, "Intermediate Prediction success", 200, context)
return ["hello world "]
```
```
Binary file added docs/images/stateful_batch.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
lxning marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import logging
from abc import ABC

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

from ts.context import Context
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)


class LlamaHandler(BaseHandler, ABC):
"""
Transformers handler class for sequence, token classification and question answering.
"""

def __init__(self):
super(LlamaHandler, self).__init__()
self.max_length = None
self.max_new_tokens = None
self.tokenizer = None
self.initialized = False

def initialize(self, ctx: Context):
"""In this initialize function, the HF large model is loaded and
partitioned using DeepSpeed.
Args:
ctx (context): It is a JSON Object containing information
pertaining to the model artifacts parameters.
"""
model_dir = ctx.system_properties.get("model_dir")
self.max_length = int(ctx.model_yaml_config["handler"]["max_length"])
self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"])
model_name = ctx.model_yaml_config["handler"]["model_name"]
model_path = f'{model_dir}/{ctx.model_yaml_config["handler"]["model_path"]}'
seed = int(ctx.model_yaml_config["handler"]["manual_seed"])
torch.manual_seed(seed)

logger.info("Model %s loading tokenizer", ctx.model_name)
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
device_map="balanced",
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
load_in_8bit=True,
trust_remote_code=True,
)
if ctx.model_yaml_config["handler"]["fast_kernels"]:
from optimum.bettertransformer import BetterTransformer

try:
self.model = BetterTransformer.transform(self.model)
except RuntimeError as error:
logger.warning(
"HuggingFace Optimum is not supporting this model,for the list of supported models, please refer to this doc,https://huggingface.co/docs/optimum/bettertransformer/overview"
)
self.tokenizer = AutoTokenizer.from_pretrained(model_path)

logger.info("Model %s loaded successfully", ctx.model_name)
self.initialized = True

def preprocess(self, requests):
"""
Basic text preprocessing, based on the user's choice of application mode.
Args:
requests (list): A list of dictionaries with a "data" or "body" field, each
containing the input text to be processed.
Returns:
tuple: A tuple with two tensors: the batch of input ids and the batch of
attention masks.
"""
input_texts = [data.get("data") or data.get("body") for data in requests]
input_ids_batch, attention_mask_batch = [], []
for input_text in input_texts:
input_ids, attention_mask = self.encode_input_text(input_text)
input_ids_batch.append(input_ids)
attention_mask_batch.append(attention_mask)
input_ids_batch = torch.cat(input_ids_batch, dim=0).to(self.model.device)
attention_mask_batch = torch.cat(attention_mask_batch, dim=0).to(self.device)
return input_ids_batch, attention_mask_batch

def encode_input_text(self, input_text):
"""
Encodes a single input text using the tokenizer.
Args:
input_text (str): The input text to be encoded.
Returns:
tuple: A tuple with two tensors: the encoded input ids and the attention mask.
"""
if isinstance(input_text, (bytes, bytearray)):
input_text = input_text.decode("utf-8")
logger.info("Received text: '%s'", input_text)
inputs = self.tokenizer.encode_plus(
input_text,
max_length=self.max_length,
padding=False,
add_special_tokens=True,
return_tensors="pt",
truncation=True,
)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
return input_ids, attention_mask

def inference(self, input_batch):
"""
Predicts the class (or classes) of the received text using the serialized transformers
checkpoint.
Args:
input_batch (tuple): A tuple with two tensors: the batch of input ids and the batch
of attention masks, as returned by the preprocess function.
Returns:
list: A list of strings with the predicted values for each input text in the batch.
"""
input_ids_batch, attention_mask_batch = input_batch
input_ids_batch = input_ids_batch.to(self.device)
outputs = self.model.generate(
input_ids_batch,
attention_mask=attention_mask_batch,
max_length=self.max_new_tokens,
)

inferences = self.tokenizer.batch_decode(
outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
)

logger.info("Generated text: %s", inferences)
return inferences

def postprocess(self, inference_output):
"""Post Process Function converts the predicted response into Torchserve readable format.
Args:
inference_output (list): It contains the predicted response of the input text.
Returns:
(list): Returns a list of the Predictions and Explanations.
"""
return inference_output
Loading
Loading