[Llama2] Prefetch llama2 tokenizer configs
-- This commit prefetches llama2 tokenizer configs from shark_tank.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Abhishek-Varma committed Sep 8, 2023
1 parent 9681d49 commit 3acdaf9
Showing 4 changed files with 18 additions and 45 deletions.
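At a glance, the new get_tokenizer flow added in apps/language_models/scripts/vicuna.py boils down to the sketch below: mirror the four tokenizer files once from the public shark_tank bucket into a local directory, then load the tokenizer from disk, so no Hugging Face auth token is needed. This is a minimal sketch, not the verbatim method; the import path for download_public_file is an assumption (it is the repository's GCS download helper used in the diff), and fetch_llama2_tokenizer is a hypothetical wrapper name.

from pathlib import Path

from transformers import AutoTokenizer

# Assumed import path for the repo's GCS helper that appears in the diff below.
from shark.shark_downloader import download_public_file

def fetch_llama2_tokenizer(cache_dir=Path.cwd() / "llama2_tokenizer_configs"):
    # Hypothetical wrapper: prefetch the tokenizer artifacts once (reused on
    # subsequent runs), then load the tokenizer from local disk.
    cache_dir.mkdir(parents=True, exist_ok=True)
    for fname in (
        "config.json",
        "special_tokens_map.json",
        "tokenizer.model",
        "tokenizer_config.json",
    ):
        download_public_file(
            f"gs://shark_tank/llama2_tokenizer/{fname}",
            cache_dir / fname,
            single_file=True,
        )
    # A local directory works with from_pretrained and sidesteps the gated
    # meta-llama Hugging Face repos, which is why --hf_auth_token goes away.
    return AutoTokenizer.from_pretrained(str(cache_dir))

The files land in llama2_tokenizer_configs/, which the .gitignore change below excludes from version control.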
3 changes: 3 additions & 0 deletions .gitignore
@@ -196,3 +196,6 @@ db_dir_UserData

# Embeded browser cache and other
apps/stable_diffusion/web/EBWebView/

+# Llama2 tokenizer configs
+llama2_tokenizer_configs/
43 changes: 15 additions & 28 deletions apps/language_models/scripts/vicuna.py
@@ -110,12 +110,6 @@
choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"],
help="Specify which model to run.",
)
-parser.add_argument(
-    "--hf_auth_token",
-    type=str,
-    default=None,
-    help="Specify your own huggingface authentication tokens for models like Llama2.",
-)
parser.add_argument(
"--cache_vicunas",
default=False,
@@ -1217,7 +1211,6 @@ def __init__(
self,
model_name,
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
-        hf_auth_token: str = None,
max_num_tokens=512,
device="cpu",
precision="int8",
@@ -1237,11 +1230,6 @@ def __init__(
max_num_tokens,
extra_args_cmd=extra_args_cmd,
)
if "llama2" in self.model_name and hf_auth_token == None:
raise ValueError(
"HF auth token required. Pass it using --hf_auth_token flag."
)
self.hf_auth_token = hf_auth_token
if self.model_name == "llama2_7b":
self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
elif self.model_name == "llama2_13b":
@@ -1276,18 +1264,26 @@ def get_model_path(self, suffix="mlir"):
)

def get_tokenizer(self):
kwargs = {"use_auth_token": self.hf_auth_token}
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path,
use_fast=False,
**kwargs,
)
+        local_tokenizer_path = Path(Path.cwd(), "llama2_tokenizer_configs")
+        local_tokenizer_path.mkdir(parents=True, exist_ok=True)
+        tokenizer_files_to_download = [
+            "config.json",
+            "special_tokens_map.json",
+            "tokenizer.model",
+            "tokenizer_config.json",
+        ]
+        for tokenizer_file in tokenizer_files_to_download:
+            download_public_file(
+                f"gs://shark_tank/llama2_tokenizer/{tokenizer_file}",
+                Path(local_tokenizer_path, tokenizer_file),
+                single_file=True,
+            )
+        tokenizer = AutoTokenizer.from_pretrained(str(local_tokenizer_path))
return tokenizer

def get_src_model(self):
kwargs = {
"torch_dtype": torch.float,
"use_auth_token": self.hf_auth_token,
}
vicuna_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path,
@@ -1460,8 +1456,6 @@ def compile(self):
self.hf_model_path,
self.precision,
self.weight_group_size,
-            self.model_name,
-            self.hf_auth_token,
)
print(f"[DEBUG] generating torchscript graph")
is_f16 = self.precision in ["fp16", "int4"]
@@ -1553,24 +1547,18 @@ def compile(self):
self.hf_model_path,
self.precision,
self.weight_group_size,
-                self.model_name,
-                self.hf_auth_token,
)
elif self.model_name == "llama2_70b":
model = SecondVicuna70B(
self.hf_model_path,
self.precision,
self.weight_group_size,
-                self.model_name,
-                self.hf_auth_token,
)
else:
model = SecondVicuna7B(
self.hf_model_path,
self.precision,
self.weight_group_size,
-                self.model_name,
-                self.hf_auth_token,
)
print(f"[DEBUG] generating torchscript graph")
is_f16 = self.precision in ["fp16", "int4"]
@@ -1809,7 +1797,6 @@ def create_prompt(model_name, history):
)
vic = UnshardedVicuna(
model_name=args.model_name,
-        hf_auth_token=args.hf_auth_token,
device=args.device,
precision=args.precision,
vicuna_mlir_path=vic_mlir_path,
16 changes: 0 additions & 16 deletions apps/language_models/src/model_wrappers/vicuna_model.py
@@ -8,13 +8,9 @@ def __init__(
model_path,
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
@@ -57,13 +53,9 @@ def __init__(
model_path,
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
@@ -303,13 +295,9 @@ def __init__(
model_path,
precision="int8",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
@@ -596,13 +584,9 @@ def __init__(
model_path,
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
1 change: 0 additions & 1 deletion apps/stable_diffusion/web/ui/stablelm_ui.py
@@ -186,7 +186,6 @@ def chat(
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
-            hf_auth_token=args.hf_auth_token,
device=device,
precision=precision,
max_num_tokens=max_toks,
