vLLM example with mistral7B model #2781

Merged
12 commits merged on Dec 9, 2023
61 changes: 61 additions & 0 deletions examples/large_models/vllm/mistral/Readme.md
@@ -0,0 +1,61 @@
# Example showing inference with vLLM on the mistralai/Mistral-7B-v0.1 model

This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on the `mistralai/Mistral-7B-v0.1` model.
vLLM achieves high throughput using PagedAttention. More details can be found [here](https://vllm.ai/).

Install vLLM with the following command:

```bash
pip install -r requirements.txt
```
### Step 1: Login to HuggingFace

Log in with a HuggingFace account:
```bash
huggingface-cli login
# or using an environment variable
huggingface-cli login --token $HUGGINGFACE_TOKEN
```

Then download the model:
```bash
python ../Download_model.py --model_path model --model_name mistralai/Mistral-7B-v0.1
```
The model will be saved locally under the `model` directory, for example `model/models--mistralai--Mistral-7B-v0.1/snapshots/<hash>`.

### Step 2: Generate MAR file

Add the downloaded model path to `model_path:` in `model-config.yaml`, then run the following command.

```bash
torch-model-archiver --model-name mistral7b --version 1.0 --handler custom_handler.py --config-file model-config.yaml -r requirements.txt --archive-format tgz
```

### Step 3: Add the MAR file to the model store

```bash
mkdir model_store
mv mistral7b.tar.gz model_store
```

### Step 4: Start TorchServe

Update `config.properties` and start TorchServe.

```bash
torchserve --start --ncs --ts-config config.properties --model-store model_store --models mistral7b.tar.gz
```
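
Once TorchServe is up, you can optionally verify that the model was registered and its worker finished loading by querying the management API on port 8081 (as set in `config.properties`). The snippet below is a minimal sketch using Python's `requests` package, which is an assumption and not part of this example's requirements; an equivalent check can be done with curl.

```python
import requests

# Describe the registered mistral7b model via the TorchServe management API.
# Assumes TorchServe is running locally with the management address from config.properties.
response = requests.get("http://localhost:8081/models/mistral7b")
response.raise_for_status()
print(response.json())  # worker list and status, including whether loading has finished
```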

### Step 5: Run inference

```bash
curl -v "http://localhost:8080/predictions/mistral7b" -T sample_text.txt
```

This results in the following output:
```
Mayonnaise is made of eggs, oil, vinegar, salt and pepper. Using an electric blender, combine all the ingredients and beat at high speed for 4 to 5 minutes.

Try it with some mustard and paprika mixed in, and a bit of sweetener if you like. But use real mayonnaise or it isn’t the same. Marlou

What in the world is mayonnaise?
```
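
The same request can also be sent programmatically. Below is a minimal sketch of a Python client using the `requests` package (an assumption, not part of this example's requirements); it mirrors the curl command above.

```python
import requests

# Read the prompt and send it to the TorchServe predictions endpoint,
# mirroring the curl command shown above.
with open("sample_text.txt", "rb") as f:
    prompt = f.read()

response = requests.post("http://localhost:8080/predictions/mistral7b", data=prompt)
response.raise_for_status()
print(response.text)  # generated continuation of the prompt
```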
5 changes: 5 additions & 0 deletions examples/large_models/vllm/mistral/config.properties
@@ -0,0 +1,5 @@
inference_address=http://127.0.0.1:8080
management_address=http://127.0.0.1:8081
metrics_address=http://127.0.0.1:8082
enable_envvars_config=true
install_py_dep_per_model=true
88 changes: 88 additions & 0 deletions examples/large_models/vllm/mistral/custom_handler.py
@@ -0,0 +1,88 @@
import logging
from abc import ABC

import torch
import vllm
from vllm import LLM, SamplingParams

from ts.context import Context
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)
logger.info("vLLM version %s", vllm.__version__)


class LlamaHandler(BaseHandler, ABC):
"""
Transformers handler class for sequence, token classification and question answering.
"""

def __init__(self):
super(LlamaHandler, self).__init__()
self.max_new_tokens = None
self.tokenizer = None
self.initialized = False

def initialize(self, ctx: Context):
"""In this initialize function, the HF large model is loaded and
partitioned using DeepSpeed.
agunapal marked this conversation as resolved.
Show resolved Hide resolved
Args:
ctx (context): It is a JSON Object containing information
pertaining to the model artifacts parameters.
"""
model_dir = ctx.system_properties.get("model_dir")
self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"])
model_name = ctx.model_yaml_config["handler"]["model_name"]
model_path = ctx.model_yaml_config["handler"]["model_path"]
seed = int(ctx.model_yaml_config["handler"]["manual_seed"])
torch.manual_seed(seed)

self.model = LLM(model=model_path)

logger.info("Model %s loaded successfully", ctx.model_name)
self.initialized = True

def preprocess(self, requests):
"""
Basic text preprocessing, based on the user's choice of application mode.
Args:
requests (list): A list of dictionaries with a "data" or "body" field, each
containing the input text to be processed.
Returns:
tuple: A tuple with two tensors: the batch of input ids and the batch of
attention masks.
"""
input_texts = [data.get("data") or data.get("body") for data in requests]
# return torch.as_tensor(input_texts, device=self.device)
input_texts = [
input_text.decode("utf-8")
for input_text in input_texts
if isinstance(input_text, (bytes, bytearray))
agunapal marked this conversation as resolved.
Show resolved Hide resolved
]
return input_texts

def inference(self, input_batch):
"""
Predicts the class (or classes) of the received text using the serialized transformers
checkpoint.
Args:
input_batch (tuple): A tuple with two tensors: the batch of input ids and the batch
agunapal marked this conversation as resolved.
Show resolved Hide resolved
of attention masks, as returned by the preprocess function.
Returns:
list: A list of strings with the predicted values for each input text in the batch.
"""
logger.info(f"Input text is {input_batch}")
sampling_params = SamplingParams(max_tokens=self.max_new_tokens)
outputs = self.model.generate(input_batch, sampling_params=sampling_params)

logger.info("Generated text: %s", outputs)
return outputs

def postprocess(self, inference_output):
"""Post Process Function converts the predicted response into Torchserve readable format.
Args:
inference_output (list): It contains the predicted response of the input text.
Returns:
(list): Returns a list of the Predictions and Explanations.
"""
return [inference_output[0].outputs[0].text]
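
For reference, the vLLM calls used by this handler can also be exercised outside TorchServe. The snippet below is a minimal standalone sketch; the snapshot path is a placeholder that should point to the directory downloaded in Step 1, and the values mirror `max_new_tokens` and the sample prompt from this example.

```python
from vllm import LLM, SamplingParams

# Load the model once, as the handler does in initialize().
# Placeholder path: substitute the snapshot directory downloaded in Step 1.
llm = LLM(model="model/models--mistralai--Mistral-7B-v0.1/snapshots/<snapshot-hash>")

# Generate up to 100 new tokens, matching max_new_tokens in model-config.yaml.
sampling_params = SamplingParams(max_tokens=100)
outputs = llm.generate(["what is the recipe of mayonnaise?"], sampling_params)

# Each RequestOutput carries one or more completions; the handler returns outputs[0].text.
for output in outputs:
    print(output.outputs[0].text)
```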
13 changes: 13 additions & 0 deletions examples/large_models/vllm/mistral/model-config.yaml
@@ -0,0 +1,13 @@
# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 100
responseTimeout: 1200
deviceType: "gpu"

handler:
model_name: "mistralai/Mistral-7B-v0.1"
model_path: "/home/ubuntu/serve/examples/large_models/vllm/mistral/model/models--mistralai--Mistral-7B-v0.1/snapshots/5e9c98b96d071dce59368012254c55b0ec6f8658"
max_new_tokens: 100
manual_seed: 40
fast_kernels: True
1 change: 1 addition & 0 deletions examples/large_models/vllm/mistral/requirements.txt
@@ -0,0 +1 @@
vllm
1 change: 1 addition & 0 deletions examples/large_models/vllm/mistral/sample_text.txt
@@ -0,0 +1 @@
what is the recipe of mayonnaise?
3 changes: 3 additions & 0 deletions ts_scripts/spellcheck_conf/wordlist.txt
@@ -1137,3 +1137,6 @@ Naver
FlashAttention
GenAI
prem
vLLM
mistralai
PagedAttention