Update llama3 example. (#743)
rohinb2 authored Apr 18, 2024
1 parent a3900a4 commit 130993e
Showing 1 changed file with 23 additions and 26 deletions.
49 changes: 23 additions & 26 deletions examples/llama3-8b-ec2/llama3_ec2.py
@@ -1,15 +1,17 @@
-# # Deploy Llama2 13B Chat Model Inference on AWS EC2
+# # Deploy Llama3 8B Chat Model Inference on AWS EC2

# This example demonstrates how to deploy a
-# [LLama2 13B model from Hugging Face](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf)
+# [LLama3 8B model from Hugging Face](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
# on AWS EC2 using Runhouse.
#
+# Make sure to sign the waiver on the model page so that you can access it.
+#
# ## Setup credentials and dependencies
#
# Optionally, set up a virtual environment:
# ```shell
-# $ conda create -n llama-demo-apps python=3.8
-# $ conda activate llama-demo-apps
+# $ conda create -n llama3-rh
+# $ conda activate llama3-rh
# ```
# Install the few required dependencies:
# ```shell
@@ -22,7 +24,7 @@
# $ aws configure
# $ sky check
# ```
-# We'll be downloading the Llama2 model from Hugging Face, so we need to set up our Hugging Face token:
+# We'll be downloading the Llama3 model from Hugging Face, so we need to set up our Hugging Face token:
# ```shell
# $ export HF_TOKEN=<your huggingface token>
# ```
@@ -55,7 +57,7 @@ def load_model(self):
self.pipeline = transformers.pipeline(
"text-generation",
model=self.model_id,
-        model_kwargs={"torch_dtype": torch.bfloat16},
+        model_kwargs=self.model_kwargs,
device="cuda",
)
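
The changed line above forwards `self.model_kwargs` to `transformers.pipeline` instead of a hard-coded dtype, but the constructor that collects those kwargs falls outside this hunk. A minimal sketch of how it plausibly looks, assuming `HFChatModel` subclasses `rh.Module` (parameter names and defaults here are illustrative assumptions, not the file's verbatim code):

```python
import runhouse as rh


class HFChatModel(rh.Module):
    def __init__(self, model_id="meta-llama/Meta-Llama-3-8B-Instruct", **model_kwargs):
        super().__init__()
        self.model_id = model_id
        # Stash construction-time kwargs (e.g. torch_dtype) so load_model can
        # forward them via model_kwargs=self.model_kwargs.
        self.model_kwargs = model_kwargs
        self.pipeline = None
```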

@@ -108,7 +110,6 @@ def predict(self, prompt_text, **inf_kwargs):
gpu = rh.cluster(
name="rh-a10x", instance_type="A10G:1", memory="32+", provider="aws"
)
-# gpu.restart_server(restart_ray=True)

# Next, we define the environment for our module. This includes the required dependencies that need
# to be installed on the remote machine, as well as any secrets that need to be synced up from local to remote.
@@ -124,38 +125,34 @@ def predict(self, prompt_text, **inf_kwargs):
"safetensors",
"scipy",
],
-    secrets=["huggingface"],  # Needed to download Llama2
-    name="llama2inference",
+    secrets=["huggingface"],  # Needed to download Llama3 from HuggingFace
+    name="llama3inference",
working_dir="./",
)
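
Only the tail of the `env` definition is visible in this hunk; the full call presumably resembles the sketch below. Everything above `"safetensors"` in `reqs` is an assumption (those lines are collapsed in the diff), while the visible arguments are copied as-is:

```python
env = rh.env(
    reqs=[
        "torch",         # assumed; hidden in the collapsed part of the diff
        "transformers",  # assumed; hidden in the collapsed part of the diff
        "accelerate",    # assumed; hidden in the collapsed part of the diff
        "safetensors",
        "scipy",
    ],
    secrets=["huggingface"],  # Needed to download Llama3 from HuggingFace
    name="llama3inference",
    working_dir="./",
)
```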

# Finally, we define our module and run it on the remote cluster. We construct it normally and then call
# `get_or_to` to run it on the remote cluster. Using `get_or_to` allows us to load the existing Module
-# by the name `llama-13b-model` if it was already put on the cluster. If we want to update the module each
+# by the name `llama3-8b-model` if it was already put on the cluster. If we want to update the module each
# time we run this script, we can use `to` instead of `get_or_to`.
#
# Note that we also pass the `env` object to the `get_or_to` method, which will ensure that the environment is
# set up on the remote machine before the module is run.
remote_hf_chat_model = HFChatModel(
-    load_in_4bit=True,  # Ignored right now
-    torch_dtype=torch.bfloat16,  # Ignored right now
-    device_map="auto",  # Ignored right now
-).get_or_to(gpu, env=env, name="llama-13b-model")
+    torch_dtype=torch.bfloat16,
+).get_or_to(gpu, env=env, name="llama3-8b-model")
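
The comment above notes that `to` can replace `get_or_to` when the module should be re-synced on every run. A sketch of that variant, reusing the `gpu` cluster and `env` defined earlier (this assumes `to` accepts the same arguments, as the surrounding comment implies):

```python
# Always push a fresh copy of the module to the cluster instead of reusing
# an existing one registered under the same name.
remote_hf_chat_model = HFChatModel(
    torch_dtype=torch.bfloat16,
).to(gpu, env=env, name="llama3-8b-model")
```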

# ## Calling our remote function
#
# We can call the `predict` method on the model class instance as if it were running locally.
# This will run the function on the remote cluster and return the response to our local machine automatically.
# Further calls will also run on the remote machine, and maintain state that was updated between calls, like
-# `self.model` and `self.tokenizer`.
-prompt = "Who are you?"
-print(remote_hf_chat_model.predict(prompt))
-# while True:
-#     prompt = input(
-#         "\n\n... Enter a prompt to chat with the model, and 'exit' to exit ...\n"
-#     )
-#     if prompt.lower() == "exit":
-#         break
-#     output = remote_hf_chat_model.predict(prompt)
-#     print("\n\n... Model Output ...\n")
-#     print(output)
+# `self.pipeline`.
+while True:
+    prompt = input(
+        "\n\n... Enter a prompt to chat with the model, and 'exit' to exit ...\n"
+    )
+    if prompt.lower().strip() == "exit":
+        break
+    output = remote_hf_chat_model.predict(prompt)
+    print("\n\n... Model Output ...\n")
+    print(output)
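
For a quick, non-interactive smoke test, a single call behaves the same way as one iteration of the loop above (the prompt string is just an example):

```python
# One-shot call: predict() runs on the remote cluster and the generated text
# is returned to the local machine.
print(remote_hf_chat_model.predict("Who are you?"))
```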
