Change Llama2 from the Turbine implementation to the Sharktank one
gpetters-amd committed Sep 19, 2024
1 parent ec75e08 commit 8359038
Showing 4 changed files with 70 additions and 34 deletions.
apps/shark_studio/api/llm.py: 54 additions & 25 deletions
@@ -1,4 +1,9 @@
from turbine_models.custom_models import stateless_llama
from shark_turbine.aot import *
from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
import sharktank
import huggingface_hub

# from turbine_models.custom_models import stateless_llama
from turbine_models.model_runner import vmfbRunner
from turbine_models.gen_external_params.gen_external_params import gen_external_params
import time
@@ -19,23 +24,23 @@

llm_model_map = {
"meta-llama/Llama-2-7b-chat-hf": {
"initializer": stateless_llama.export_transformer_model,
# "initializer": stateless_llama.export_transformer_model,
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
"Trelis/Llama-2-7b-chat-hf-function-calling-v2": {
"initializer": stateless_llama.export_transformer_model,
# "initializer": stateless_llama.export_transformer_model,
"hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
"compile_flags": ["--iree-opt-const-expr-hoisting=False"],
"stop_token": 2,
"max_tokens": 4096,
"system_prompt": """<s>[INST] <<SYS>>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>>""",
},
"TinyPixel/small-llama2": {
"initializer": stateless_llama.export_transformer_model,
# "initializer": stateless_llama.export_transformer_model,
"hf_model_name": "TinyPixel/small-llama2",
"compile_flags": ["--iree-opt-const-expr-hoisting=True"],
"stop_token": 2,
@@ -130,13 +135,18 @@ def __init__(
print(
f"External weight file {self.external_weight_file} does not exist. Generating..."
)
gen_external_params(
hf_model_name=self.hf_model_name,
quantization=self.quantization,
weight_path=self.external_weight_file,
hf_auth_token=hf_auth_token,
precision=self.precision,
# gen_external_params(
# hf_model_name=self.hf_model_name,
# quantization=self.quantization,
# weight_path=self.external_weight_file,
# hf_auth_token=hf_auth_token,
# precision=self.precision,
# )
cache_dir = os.path.join(".", str(self.hf_model_name).replace("/", "_"))
huggingface_hub.snapshot_download(
repo_id=self.hf_model_name, cache_dir=cache_dir
)
# TODO: Convert to gguf, delete cache
else:
print(
f"External weight file {self.external_weight_file} found for {self.vmfb_name}"
@@ -161,20 +171,39 @@ def __init__(
use_auth_token=hf_auth_token,
)
elif not os.path.exists(self.tempfile_name):
self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][
"initializer"
](
self.hf_model_name,
hf_auth_token,
compile_to="torch",
external_weights=external_weights,
precision=self.precision,
quantization=self.quantization,
streaming_llm=self.streaming_llm,
decomp_attn=True,
# self.torch_ir, self.tokenizer = llm_model_map[self.hf_model_name][
# "initializer"
# ](
# self.hf_model_name,
# hf_auth_token,
# compile_to="torch",
# external_weights=external_weights,
# precision=self.precision,
# quantization=self.quantization,
# streaming_llm=self.streaming_llm,
# decomp_attn=True,
# )

dataset = sharktank.types.Dataset.load(
self.external_weight_file, file_type="gguf"
)
hp = sharktank.layers.configs.LlamaHParams.from_gguf_props(
dataset.properties
)
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
llama_config = sharktank.models.llama.llama.LlamaModelConfig(hp)
llama_config.use_hf = False
llama_config.static_tables = (
False # Rely on the compiler for hoisting tables.
)
llama_config.kv_cache_type = "direct" # if args.bs == [1] else "paged"
model = PagedLlamaModelV1(dataset.root_theta, llama_config)

fxb = FxProgramsBuilder(model)
self.torch_ir = export(fxb)
self.torch_ir.save_mlir(self.tempfile_name)

# with open(self.tempfile_name, "w+") as f:
# f.write(self.torch_ir)
del self.torch_ir
gc.collect()
self.compile()
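Taken together, the new export path in this hunk reads roughly as the standalone sketch below: weights and hyperparameters are loaded from a GGUF file, a PagedLlamaModelV1 is built from them, and MLIR is produced through FxProgramsBuilder. The helper name and file paths are illustrative only, and the registration of prefill/decode entry points on the builder (normally done via fxb.export_program before export()) is elided here, as it is in the diff.

```python
# Standalone sketch of the new Sharktank export path; the helper name and
# file paths are illustrative, not part of this commit.
from shark_turbine.aot import FxProgramsBuilder, export
from sharktank.types import Dataset
from sharktank.layers.configs import LlamaHParams
from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1


def export_llama_to_mlir(gguf_path: str, mlir_path: str) -> None:
    # Weights and hyperparameters both come straight from the GGUF file.
    dataset = Dataset.load(gguf_path, file_type="gguf")
    hp = LlamaHParams.from_gguf_props(dataset.properties)

    config = LlamaModelConfig(hp)
    config.use_hf = False
    config.static_tables = False     # rely on the compiler to hoist tables
    config.kv_cache_type = "direct"  # "paged" would suit batched serving

    model = PagedLlamaModelV1(dataset.root_theta, config)

    # Prefill/decode entry points are normally registered on the builder via
    # fxb.export_program(...) before export(); that step is elided here too.
    fxb = FxProgramsBuilder(model)
    export(fxb).save_mlir(mlir_path)
```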
@@ -413,7 +442,7 @@ def llm_chat_api(InputData: dict):
hf_auth_token=cmd_opts.hf_auth_token,
device=device,
quantization=cmd_opts.quantization,
external_weights="safetensors",
external_weights="gguf",
use_system_prompt=True,
streaming_llm=False,
)
@@ -467,7 +496,7 @@ def llm_chat_api(InputData: dict):
"Trelis/Llama-2-7b-chat-hf-function-calling-v2",
hf_auth_token=None,
device="cpu-task",
external_weights="safetensors",
external_weights="gguf",
)

print("model loaded")
apps/shark_studio/api/sd.py: 10 additions & 8 deletions
@@ -169,14 +169,14 @@ def __init__(
batch_size=batch_size,
num_inference_steps=steps,
device=target_backend,
iree_target_triple=triple,
target=triple, # iree_target_triple=triple,
ireec_flags=EMPTY_FLAGS,
attn_spec=attn_spec,
decomp_attn=decomp_attn,
pipeline_dir=self.pipeline_dir,
external_weights_dir=self.weights_path,
external_weights=external_weights,
custom_vae=custom_vae,
# custom_vae=custom_vae,
)
print(f"\n[LOG] Pipeline initialized with pipe_id: {self.pipe_id}.")
gc.collect()
@@ -237,13 +237,15 @@ def prepare_pipe(
)
weights[key] = save_irpa(vae_weights_path, "vae.")

vmfbs, weights = self.sd_pipe.check_prepared(
mlirs, vmfbs, weights, interactive=False
)
# vmfbs, weights = self.sd_pipe.check_prepared(
# mlirs, vmfbs, weights, interactive=False
# )
self.sd_pipe.prepare_all()
print(f"\n[LOG] Loading pipeline to device {self.rt_device}.")
self.sd_pipe.load_pipeline(
vmfbs, weights, self.rt_device, self.compiled_pipeline
)
# self.sd_pipe.load_pipeline(
# vmfbs, weights, self.rt_device, self.compiled_pipeline
# )
self.sd_pipe.load_map()
print(
"\n[LOG] Pipeline successfully prepared for runtime. Generating images..."
)
apps/shark_studio/web/ui/chat.py: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def chat_fn(
model,
device=device,
precision=precision,
external_weights="safetensors",
external_weights="gguf",
use_system_prompt=prompt_prefix,
streaming_llm=streaming_llm,
hf_auth_token=cmd_opts.hf_auth_token,
requirements.txt: 5 additions & 0 deletions
@@ -39,7 +39,12 @@ py-cpuinfo
pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions
mpmath==1.3.0
optimum
fastapi<0.113.0

# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile
pyinstaller

sharktank
gguf
huggingface_hub
