Add llama2 text completion example with streaming response support
Naman Nandan committed Jul 27, 2023
1 parent 71e1c2e · commit b9f2654
Showing 8 changed files with 136 additions and 189 deletions.
167 changes: 0 additions & 167 deletions
examples/large_models/inferentia2/llama/inf2_handler.py
This file was deleted.
This file was deleted.
This file was deleted.
100 changes: 100 additions & 0 deletions
examples/large_models/inferentia2/llama2/inf2_handler.py
import logging
import os
from abc import ABC
from threading import Thread

import torch_neuronx
from transformers import AutoConfig, LlamaTokenizer, TextIteratorStreamer
from transformers_neuronx.generation_utils import HuggingFaceGenerationModelAdapter
from transformers_neuronx.llama.model import LlamaForSampling

from ts.protocol.otf_message_handler import send_intermediate_predict_response
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class LLMHandler(BaseHandler, ABC):
    """
    Transformers handler class for text completion streaming on Inferentia2
    """

    def __init__(self):
        super(LLMHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")

        # settings for model compilation and loading
        model_name = ctx.model_yaml_config["handler"]["model_name"]
        tp_degree = ctx.model_yaml_config["handler"]["tp_degree"]
        self.max_length = ctx.model_yaml_config["handler"]["max_length"]

        # allocate "tp_degree" number of neuron cores to the worker process
        os.environ["NEURON_RT_NUM_CORES"] = str(tp_degree)
        try:
            num_neuron_cores_available = (
                torch_neuronx.xla_impl.data_parallel.device_count()
            )
            assert num_neuron_cores_available >= int(tp_degree)
        except (RuntimeError, AssertionError) as error:
            raise RuntimeError(
                "Required number of neuron cores for tp_degree "
                + str(tp_degree)
                + " is not available: "
                + str(error)
            )

        os.environ["NEURON_CC_FLAGS"] = "--model-type=transformer-inference"

        self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
        self.model = LlamaForSampling.from_pretrained(
            model_dir, batch_size=1, tp_degree=tp_degree
        )
        logger.info("Starting to compile the model")
        self.model.to_neuron()
        logger.info("Model has been successfully compiled")
        # wrap the compiled model so it exposes the HuggingFace generate() API
        model_config = AutoConfig.from_pretrained(model_dir)
        self.model = HuggingFaceGenerationModelAdapter(model_config, self.model)
        # TextIteratorStreamer yields decoded text as generate() produces tokens
        self.output_streamer = TextIteratorStreamer(self.tokenizer)

        self.initialized = True

    def preprocess(self, requests):
        input_texts = []
        for req in requests:
            data = req.get("data") or req.get("body")
            if isinstance(data, (bytes, bytearray)):
                data = data.decode("utf-8")
            input_texts.append(data)

        return self.tokenizer(input_texts, return_tensors="pt")

    def inference(self, tokenized_input):
        generation_kwargs = dict(
            tokenized_input,
            streamer=self.output_streamer,
            max_new_tokens=self.max_length,
        )
        self.model.reset_generation()
        # run generation in a background thread so this thread can consume
        # the streamer and send intermediate responses as tokens arrive
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        for new_text in self.output_streamer:
            send_intermediate_predict_response(
                [new_text],
                self.context.request_ids,
                "Intermediate Prediction success",
                200,
                self.context,
            )

        thread.join()

        return [""]

    def postprocess(self, inference_output):
        return inference_output
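The handler above reads its settings from the model archive's YAML config (ctx.model_yaml_config). A minimal sketch of a matching model-config.yaml follows; the keys under "handler" are exactly the ones initialize() reads, while the model name, worker counts, and specific values are illustrative assumptions rather than part of this commit:

# hypothetical model-config.yaml for this handler
minWorkers: 1
maxWorkers: 1
responseTimeout: 1800
handler:
    model_name: "meta-llama/Llama-2-7b-hf"  # HuggingFace id used for the tokenizer
    tp_degree: 2       # neuron cores to shard the model across
    max_length: 50     # passed to generate() as max_new_tokens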
5 changes: 5 additions & 0 deletions
examples/large_models/inferentia2/llama2/requirements.txt
torch-neuronx
transformers-neuronx
transformers
tokenizers
sentencepiece
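Note: the Neuron packages (torch-neuronx, transformers-neuronx) are typically installed from the AWS Neuron pip repository, e.g. by passing pip's --extra-index-url https://pip.repos.neuron.amazonaws.com; this is general Neuron setup guidance, not something specified in this commit.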
14 changes: 14 additions & 0 deletions
examples/large_models/inferentia2/llama2/test_stream_response.py
@@ -0,0 +1,14 @@ | ||
import requests | ||
|
||
response = requests.post( | ||
"http://localhost:8080/predictions/llama-2-7b", | ||
data="Today the weather is really nice and I am planning on ", | ||
stream=True, | ||
) | ||
|
||
for chunk in response.iter_content(chunk_size=None): | ||
if chunk: | ||
data = chunk.decode("utf-8") | ||
print(data, end="", flush=True) | ||
|
||
print("") |