
Example for Llama2 on Inf2 #2458

Merged
merged 27 commits into from Sep 19, 2023
27 commits
2838ad0
add llama example
Jul 11, 2023
73a0ba4
Update documentation and increase response timeout config
Jul 12, 2023
f5c0855
Update example to use decapoda research llama model
Jul 14, 2023
5c72745
update documentation
Jul 14, 2023
0d06004
refactor save_split_checkpoint script and use no-archive format
Jul 19, 2023
c5720c5
Add llama2 text completion example with streaming response support
Jul 27, 2023
89e4ffb
Update handler to only support batch size 1
Aug 1, 2023
483004b
Add batching support with stream response
Aug 8, 2023
83dc576
add support for batching and fix documentation
Aug 9, 2023
7920a96
update example for OPT model
Aug 9, 2023
a0c199c
Add support for micro batching
Aug 16, 2023
3e0697b
Enable saving tokenizer model when saving split checkpoints
Aug 23, 2023
2993154
Move HF batch streamer to handler_utils
Aug 23, 2023
98e2a94
Add support for amp
Aug 23, 2023
f50ac63
Refactor microbatch implementation to only handle batch index
Aug 26, 2023
9af1611
fix requirements package versions
Sep 6, 2023
b0392cc
Add support for neuronx compilation artifact caching
Sep 6, 2023
5ac4696
Update Readme to include more details about batch size and microbatch…
Sep 13, 2023
3da4a78
add config.properties file
Sep 13, 2023
b747cd3
Update Readme to include details about AOT compilation
Sep 13, 2023
7c1b130
Refactor custom handler
Sep 14, 2023
80eb640
add file links in readme
Sep 17, 2023
e55fd86
Merge branch 'master' into naman-inf2-example-refactor
namannandan Sep 18, 2023
36b0d96
Add more details in Readme about using precompiled model artifacts
Sep 18, 2023
5176dce
fix linter error
Sep 18, 2023
a983a01
Include model checkpoint files in a separate directory
Sep 19, 2023
cfaf385
Merge branch 'master' into naman-inf2-example-refactor
namannandan Sep 19, 2023
112 changes: 112 additions & 0 deletions examples/large_models/inferentia2/llama2/Readme.md
@@ -0,0 +1,112 @@
# Large model inference on Inferentia2

This document covers serving the [Llama 2](https://huggingface.co/meta-llama) model on [AWS Inferentia2](https://aws.amazon.com/ec2/instance-types/inf2/) for text completion with [micro batching](https://github.com/pytorch/serve/tree/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/examples/micro_batching) and [streaming response](https://github.com/pytorch/serve/blob/96450b9d0ab2a7290221f0e07aea5fda8a83efaf/docs/inference_api.md#curl-example-1) support.

Inferentia2 uses the [Neuron SDK](https://aws.amazon.com/machine-learning/neuron/), which is built on top of the PyTorch XLA stack. For large model inference, the [`transformers-neuronx`](https://github.com/aws-neuron/transformers-neuronx) package is used, which takes care of model partitioning and running inference.

**Note**: To run the model on an Inf2 instance, the model must be compiled as a preprocessing step. The compilation process generates the model graph for a specific batch size, so when running inference we need to pass input that matches the batch size used during compilation. Model compilation and input padding to match the compiled model batch size are taken care of by the [custom handler](inf2_handler.py) in this example.

The batch size and micro batch size configurations are present in [model-config.yaml](model-config.yaml). The batch size indicates the maximum number of requests torchserve will aggregate and send to the custom handler within the batch delay.
The batch size is chosen to be a relatively large value, say 16, since micro batching enables running the preprocess (tokenization) and inference steps in parallel on the micro batches. The micro batch size is the batch size used for the Inf2 model compilation.
Since the compilation batch size can influence compile time and is also constrained by the Inf2 instance type, it is chosen to be a relatively small value, say 4.

This example also demonstrates the use of the neuronx cache to store Inf2 model compilation artifacts using the `NEURONX_CACHE` and `NEURONX_DUMP_TO` environment variables in the custom handler.
When the model is loaded for the first time, the model is compiled for the configured micro batch size and the compilation artifacts are saved to the neuronx cache.
On subsequent model loads, the compilation artifacts in the neuronx cache serve as `Ahead-of-Time (AOT)` compilation artifacts and significantly reduce the model load time.
For convenience, the compiled model artifacts for this example are made available on the Torchserve model zoo: `s3://torchserve/mar_files/llama-2-13b-neuronx-b4`\
Instructions on how to use the AOT compiled model artifacts are shown below.
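
The handler enables this cache with a couple of environment variables set at initialization; a minimal runnable sketch of that setup (the `model_dir` value below is hypothetical, for illustration only):

```python
import os

# Sketch of the neuronx cache setup done in inf2_handler.py: compiled artifacts are
# written under the model directory so subsequent loads reuse them as AOT artifacts.
model_dir = "model_store/llama-2-13b"  # hypothetical path, for illustration
os.environ["NEURONX_CACHE"] = "on"
os.environ["NEURONX_DUMP_TO"] = f"{model_dir}/neuron_cache"
os.environ["NEURON_CC_FLAGS"] = "--model-type=transformer-inference"
```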

### Step 1: Inf2 instance

Get an Inf2 instance (note: this example was tested on the `inf2.24xlarge` instance type), SSH to it, and make sure to use the following DLAMI, as it comes with PyTorch and the necessary packages for the AWS Neuron SDK pre-installed.
DLAMI Name: `Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04) 20230720 Amazon Machine Image (AMI)` or higher.
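
If you prefer to look up the AMI ID programmatically, a convenience command is sketched below (not part of the original instructions; it assumes the AMI name prefix still matches and that the AWS CLI is configured for your region):

```bash
# Find the most recent matching Neuron PyTorch 1.13 DLAMI in the current region.
aws ec2 describe-images \
  --owners amazon \
  --filters "Name=name,Values=Deep Learning AMI Neuron PyTorch 1.13 (Ubuntu 20.04)*" \
  --query 'sort_by(Images, &CreationDate)[-1].[ImageId,Name]' \
  --output text
```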

### Step 2: Package Installations

Follow the steps below to complete the package installations.

```bash
sudo apt-get update
sudo apt-get upgrade

# Update Neuron Runtime
sudo apt-get install aws-neuronx-collectives=2.* -y
sudo apt-get install aws-neuronx-runtime-lib=2.* -y

# Activate Python venv
source /opt/aws_neuron_venv_pytorch/bin/activate

# Clone Torchserve git repository
git clone https://github.com/pytorch/serve.git
cd serve

# Install dependencies
python ts_scripts/install_dependencies.py --neuronx --environment=dev

# Install torchserve and torch-model-archiver
python ts_scripts/install_from_src.py

# Navigate to `examples/large_models/inferentia2/llama2` directory
cd examples/large_models/inferentia2/llama2/

# Install additional necessary packages
python -m pip install -r requirements.txt
```
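
Optionally, you can sanity-check the Neuron setup before proceeding (a sketch; it assumes the DLAMI ships the `aws-neuronx-tools` package that provides `neuron-ls`):

```bash
# List the visible Neuron devices and confirm the Python packages import cleanly.
neuron-ls
python -c "import torch_neuronx, transformers_neuronx; print('Neuron packages OK')"
```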

### Step 3: Save the model artifacts compatible with `transformers-neuronx`
To use the pre-compiled model artifacts, copy them from the model zoo using the command shown below and skip to **Step 5**.
```bash
aws s3 cp s3://torchserve/mar_files/llama-2-13b-neuronx-b4/ llama-2-13b --recursive
```

To download and compile the Llama2 model from scratch for Inf2:\
Request access to the Llama2 model\
https://huggingface.co/meta-llama/Llama-2-13b-hf

Login to Huggingface
```bash
huggingface-cli login
```

Run the `inf2_save_split_checkpoints.py` script
```bash
python ../util/inf2_save_split_checkpoints.py --model_name meta-llama/Llama-2-13b-hf --save_path './llama-2-13b-split'
```


### Step 4: Package model artifacts

```bash
torch-model-archiver --model-name llama-2-13b --version 1.0 --handler inf2_handler.py --extra-files ./llama-2-13b-split -r requirements.txt --config-file model-config.yaml --archive-format no-archive
```
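
With `--archive-format no-archive`, the archiver produces a plain `llama-2-13b` directory instead of a `.mar` file. A quick way to confirm the output (the exact contents depend on the extra files, but a manifest is always generated):

```bash
# Inspect the generated model directory and its manifest.
ls llama-2-13b
cat llama-2-13b/MAR-INF/MANIFEST.json
```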

### Step 5: Add the model artifacts to model store

```bash
mkdir model_store
mv llama-2-13b model_store
```

### Step 6: Start torchserve

```bash
torchserve --ncs --start --model-store model_store --ts-config config.properties
```
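
Optionally, confirm the frontend is up before registering the model (standard TorchServe health check; the default inference port 8080 is assumed since `config.properties` does not override it):

```bash
curl http://localhost:8080/ping
```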

### Step 7: Register model

```bash
curl -X POST "http://localhost:8081/models?url=llama-2-13b"
```
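
To verify that registration succeeded and a worker has been assigned, you can query the standard TorchServe management API on the default management port 8081:

```bash
curl http://localhost:8081/models/llama-2-13b
```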

### Step 8: Run inference

```bash
python test_stream_response.py
```
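
As an alternative to the Python client, here is a sketch of an equivalent `curl` call using the same prompt as `test_stream_response.py` (`--no-buffer` prints chunks as they arrive):

```bash
curl --no-buffer -X POST http://localhost:8080/predictions/llama-2-13b \
  -d "Today the weather is really nice and I am planning on "
```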

### Step 9: Stop torchserve

```bash
torchserve --stop
```
1 change: 1 addition & 0 deletions examples/large_models/inferentia2/llama2/config.properties
@@ -0,0 +1 @@
install_py_dep_per_model=true
159 changes: 159 additions & 0 deletions examples/large_models/inferentia2/llama2/inf2_handler.py
Collaborator:
Would be great to have a unit test for the handler. You can mock out the Inferentia and model-related parts. This example shows how to mock the context, etc.
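
A minimal sketch of such a test, assuming pytest and `unittest.mock`, run from the example directory with torchserve installed from source (Step 2). The Neuron-specific packages are stubbed so the handler module can be imported without an Inf2 instance; the test names and mock shapes are hypothetical, and only the padding and request-ID mapping logic is exercised:

```python
import sys
from unittest.mock import MagicMock

# Stub out Inferentia-specific packages so inf2_handler can be imported
# on a machine without the Neuron SDK installed.
for mod in (
    "torch_neuronx",
    "transformers_neuronx",
    "transformers_neuronx.generation_utils",
    "transformers_neuronx.llama",
    "transformers_neuronx.llama.model",
):
    sys.modules.setdefault(mod, MagicMock())

from inf2_handler import LLMHandler  # noqa: E402


def test_preprocess_pads_to_micro_batch_size():
    handler = LLMHandler()
    handler.handle.micro_batch_size = 4
    # Replace the tokenizer with a passthrough so we only check padding behaviour.
    handler.tokenizer = MagicMock(side_effect=lambda texts, **kwargs: texts)

    padded = handler.preprocess([{"data": b"hello"}, {"body": "world"}])

    assert padded == ["hello", "world", "", ""]


def test_micro_batch_req_id_map_uses_mocked_context():
    handler = LLMHandler()
    handler.handle.micro_batch_size = 2
    handler.context = MagicMock()
    handler.context.request_ids = {0: "id-a", 1: "id-b", 2: "id-c"}

    assert handler.get_micro_batch_req_id_map(0) == {0: "id-a", 1: "id-b"}
    assert handler.get_micro_batch_req_id_map(1) == {0: "id-c"}
```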

@@ -0,0 +1,159 @@
import logging
import os
from abc import ABC
from threading import Thread

import torch_neuronx
from transformers import AutoConfig, LlamaTokenizer
from transformers_neuronx.generation_utils import HuggingFaceGenerationModelAdapter
from transformers_neuronx.llama.model import LlamaForSampling

from ts.handler_utils.hf_batch_streamer import TextIteratorStreamerBatch
from ts.handler_utils.micro_batching import MicroBatching
from ts.protocol.otf_message_handler import send_intermediate_predict_response
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class LLMHandler(BaseHandler, ABC):
"""
Transformers handler class for text completion streaming on Inferentia2
"""

def __init__(self):
super().__init__()
self.initialized = False
self.max_length = None
self.tokenizer = None
self.output_streamer = None
# enable micro batching
self.handle = MicroBatching(self)

def initialize(self, ctx):
self.manifest = ctx.manifest
properties = ctx.system_properties
model_dir = properties.get("model_dir")
os.environ["NEURONX_CACHE"] = "on"
os.environ["NEURONX_DUMP_TO"] = f"{model_dir}/neuron_cache"
os.environ["NEURON_CC_FLAGS"] = "--model-type=transformer-inference"

# micro batching initialization
micro_batching_parallelism = ctx.model_yaml_config.get(
"micro_batching", {}
).get("parallelism", None)
if micro_batching_parallelism:
logger.info(
f"Setting micro batching parallelism from model_config_yaml: {micro_batching_parallelism}"
)
self.handle.parallelism = micro_batching_parallelism

micro_batch_size = ctx.model_yaml_config.get("micro_batching", {}).get(
"micro_batch_size", 1
)
logger.info(f"Setting micro batching size: {micro_batch_size}")
self.handle.micro_batch_size = micro_batch_size

# settings for model compilation and loading
amp = ctx.model_yaml_config.get("handler", {}).get("amp", "f32")
tp_degree = ctx.model_yaml_config.get("handler", {}).get("tp_degree", 6)
self.max_length = ctx.model_yaml_config.get("handler", {}).get("max_length", 50)

# allocate "tp_degree" number of neuron cores to the worker process
os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree)
try:
num_neuron_cores_available = (
torch_neuronx.xla_impl.data_parallel.device_count()
)
assert num_neuron_cores_available >= int(tp_degree)
except (RuntimeError, AssertionError) as error:
logger.error(
"Required number of neuron cores for tp_degree "
+ str(tp_degree)
+ " are not available: "
+ str(error)
)

raise error

self.tokenizer = LlamaTokenizer.from_pretrained(model_dir)
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model = LlamaForSampling.from_pretrained(
model_dir,
batch_size=self.handle.micro_batch_size,
amp=amp,
tp_degree=tp_degree,
)
logger.info("Starting to compile the model")
self.model.to_neuron()
Collaborator:
@namannandan I am wondering if compilation can be done ahead of time and we just load the compiled graphs here, the way it worked for Inf1?

Collaborator:
I tested `_save_compiled_artifacts`. It is able to generate a Neuron model; however, `transformers_neuronx` still needs to recompile. I have already let the Neuron team know that the experimental `_save_compiled_artifacts` feature needs more work.

logger.info("Model has been successfully compiled")
model_config = AutoConfig.from_pretrained(model_dir)
self.model = HuggingFaceGenerationModelAdapter(model_config, self.model)
self.output_streamer = TextIteratorStreamerBatch(
self.tokenizer,
batch_size=self.handle.micro_batch_size,
skip_special_tokens=True,
)

self.initialized = True

def preprocess(self, requests):
input_text = []
for req in requests:
data = req.get("data") or req.get("body")
if isinstance(data, (bytes, bytearray)):
data = data.decode("utf-8")
logger.info(f"received req={data}")
input_text.append(data.strip())

# Ensure the compiled model can handle the input received
if len(input_text) > self.handle.micro_batch_size:
raise ValueError(
f"Model is compiled for batch size {self.handle.micro_batch_size} but received input of size {len(input_text)}"
)

# Pad input to match compiled model batch size
input_text.extend([""] * (self.handle.micro_batch_size - len(input_text)))

return self.tokenizer(input_text, return_tensors="pt", padding=True)

def inference(self, tokenized_input):
generation_kwargs = dict(
tokenized_input,
streamer=self.output_streamer,
max_new_tokens=self.max_length,
)
self.model.reset_generation()
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()

micro_batch_idx = self.handle.get_micro_batch_idx()
micro_batch_req_id_map = self.get_micro_batch_req_id_map(micro_batch_idx)
for new_text in self.output_streamer:
logger.debug("send response stream")
send_intermediate_predict_response(
new_text[: len(micro_batch_req_id_map)],
micro_batch_req_id_map,
"Intermediate Prediction success",
200,
self.context,
)

thread.join()

return [""] * len(micro_batch_req_id_map)

def postprocess(self, inference_output):
return inference_output

def get_micro_batch_req_id_map(self, micro_batch_idx: int):
start_idx = micro_batch_idx * self.handle.micro_batch_size
micro_batch_req_id_map = {
index: self.context.request_ids[batch_index]
for index, batch_index in enumerate(
range(start_idx, start_idx + self.handle.micro_batch_size)
)
if batch_index in self.context.request_ids
}

return micro_batch_req_id_map
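
For readers following the bookkeeping above, here is a small self-contained illustration of the index arithmetic in `get_micro_batch_req_id_map` (the request IDs below are made up; a partially filled last micro batch yields a shorter map, so padded inputs get no response):

```python
# Standalone illustration of how a micro batch index maps to request IDs.
# request_ids mimics ctx.request_ids for an aggregated batch of 6 requests.
micro_batch_size = 4
request_ids = {0: "req-a", 1: "req-b", 2: "req-c", 3: "req-d", 4: "req-e", 5: "req-f"}


def micro_batch_req_id_map(micro_batch_idx: int) -> dict:
    start_idx = micro_batch_idx * micro_batch_size
    return {
        index: request_ids[batch_index]
        for index, batch_index in enumerate(
            range(start_idx, start_idx + micro_batch_size)
        )
        if batch_index in request_ids
    }


print(micro_batch_req_id_map(0))  # {0: 'req-a', 1: 'req-b', 2: 'req-c', 3: 'req-d'}
print(micro_batch_req_id_map(1))  # {0: 'req-e', 1: 'req-f'}
```
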
17 changes: 17 additions & 0 deletions examples/large_models/inferentia2/llama2/model-config.yaml
@@ -0,0 +1,17 @@
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 100
responseTimeout: 10800
batchSize: 16

handler:
amp: "bf16"
tp_degree: 6
max_length: 100

micro_batching:
micro_batch_size: 4
parallelism:
preprocess: 2
inference: 1
postprocess: 2
6 changes: 6 additions & 0 deletions examples/large_models/inferentia2/llama2/requirements.txt
@@ -0,0 +1,6 @@
--extra-index-url https://pip.repos.neuron.amazonaws.com
torch-neuronx==1.13.1.1.9.0
transformers-neuronx==0.5.58
transformers==4.31.0
tokenizers==0.13.3
sentencepiece==0.1.99
14 changes: 14 additions & 0 deletions examples/large_models/inferentia2/llama2/test_stream_response.py
@@ -0,0 +1,14 @@
import requests

response = requests.post(
"http://localhost:8080/predictions/llama-2-13b",
data="Today the weather is really nice and I am planning on ",
stream=True,
)

for chunk in response.iter_content(chunk_size=None):
if chunk:
data = chunk.decode("utf-8")
print(data, end="", flush=True)

print("")
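
To exercise the batching and micro batching path described in the Readme, several prompts can be sent concurrently; a rough sketch (the prompt text and request count are illustrative):

```bash
# Fire four streaming requests in parallel so torchserve can aggregate them into a batch.
for i in 1 2 3 4; do
  curl --no-buffer -X POST http://localhost:8080/predictions/llama-2-13b \
    -d "Prompt $i: Today the weather is really nice and I am planning on " &
done
wait
```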