Skip to content

Commit

Permalink
Remove hf_auth_token use
Browse files Browse the repository at this point in the history
-- This commit removes `--hf_auth_token` uses from vicuna.py.
-- It adds llama2 models based on daryl149's HF repos.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
  • Loading branch information
Abhishek-Varma committed Sep 8, 2023
1 parent bde63ee commit 1745c8f
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 56 deletions.
41 changes: 6 additions & 35 deletions apps/language_models/scripts/vicuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,6 @@
choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"],
help="Specify which model to run.",
)
parser.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication tokens for models like Llama2.",
)
parser.add_argument(
"--cache_vicunas",
default=False,
Expand Down Expand Up @@ -460,10 +454,6 @@ def __init__(

def get_tokenizer(self):
kwargs = {}
if self.model_name == "llama2":
kwargs = {
"use_auth_token": "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"
}
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path,
use_fast=False,
Expand Down Expand Up @@ -1217,7 +1207,6 @@ def __init__(
self,
model_name,
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
hf_auth_token: str = None,
max_num_tokens=512,
device="cpu",
precision="int8",
Expand All @@ -1237,17 +1226,12 @@ def __init__(
max_num_tokens,
extra_args_cmd=extra_args_cmd,
)
if "llama2" in self.model_name and hf_auth_token == None:
raise ValueError(
"HF auth token required. Pass it using --hf_auth_token flag."
)
self.hf_auth_token = hf_auth_token
if self.model_name == "llama2_7b":
self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
self.hf_model_path = "daryl149/llama-2-7b-chat-hf"
elif self.model_name == "llama2_13b":
self.hf_model_path = "meta-llama/Llama-2-13b-chat-hf"
self.hf_model_path = "daryl149/llama-2-13b-chat-hf"
elif self.model_name == "llama2_70b":
self.hf_model_path = "meta-llama/Llama-2-70b-chat-hf"
self.hf_model_path = "daryl149/llama-2-70b-chat-hf"
print(f"[DEBUG] hf model name: {self.hf_model_path}")
self.max_sequence_length = 256
self.device = device
Expand Down Expand Up @@ -1276,18 +1260,15 @@ def get_model_path(self, suffix="mlir"):
)

def get_tokenizer(self):
kwargs = {"use_auth_token": self.hf_auth_token}
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path,
use_fast=False,
**kwargs,
)
return tokenizer

def get_src_model(self):
kwargs = {
"torch_dtype": torch.float,
"use_auth_token": self.hf_auth_token,
}
vicuna_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path,
Expand Down Expand Up @@ -1460,8 +1441,6 @@ def compile(self):
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
print(f"[DEBUG] generating torchscript graph")
is_f16 = self.precision in ["fp16", "int4"]
Expand Down Expand Up @@ -1553,24 +1532,18 @@ def compile(self):
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
elif self.model_name == "llama2_70b":
model = SecondVicuna70B(
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
else:
model = SecondVicuna7B(
self.hf_model_path,
self.precision,
self.weight_group_size,
self.model_name,
self.hf_auth_token,
)
print(f"[DEBUG] generating torchscript graph")
is_f16 = self.precision in ["fp16", "int4"]
Expand Down Expand Up @@ -1714,7 +1687,6 @@ def generate(self, prompt, cli):
logits = generated_token_op["logits"]
pkv = generated_token_op["past_key_values"]
detok = generated_token_op["detok"]

if token == 2:
break
res_tokens.append(token)
Expand Down Expand Up @@ -1809,7 +1781,6 @@ def create_prompt(model_name, history):
)
vic = UnshardedVicuna(
model_name=args.model_name,
hf_auth_token=args.hf_auth_token,
device=args.device,
precision=args.precision,
vicuna_mlir_path=vic_mlir_path,
Expand Down Expand Up @@ -1851,9 +1822,9 @@ def create_prompt(model_name, history):

# Maps each --model_name choice to a "<model_name>=><HF repo path>" string.
# The prefix before "=>" must equal the dict key; the llama2_13b and
# llama2_70b entries previously carried a copy-pasted "llama2_7b" prefix.
# NOTE(review): the meta-llama entries are the pre-commit side of this diff
# and the daryl149 entries supersede them (in a dict literal the later
# duplicate key wins).
model_list = {
    "vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF",
    "llama2_7b": "llama2_7b=>meta-llama/Llama-2-7b-chat-hf",
    "llama2_13b": "llama2_13b=>meta-llama/Llama-2-13b-chat-hf",
    "llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf",
    "llama2_7b": "llama2_7b=>daryl149/llama-2-7b-chat-hf",
    "llama2_13b": "llama2_13b=>daryl149/llama-2-13b-chat-hf",
    "llama2_70b": "llama2_70b=>daryl149/llama-2-70b-chat-hf",
}
while True:
# TODO: Add break condition from user input
Expand Down
16 changes: 0 additions & 16 deletions apps/language_models/src/model_wrappers/vicuna_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,9 @@ def __init__(
model_path,
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
Expand Down Expand Up @@ -57,13 +53,9 @@ def __init__(
model_path,
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
Expand Down Expand Up @@ -303,13 +295,9 @@ def __init__(
model_path,
precision="int8",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
Expand Down Expand Up @@ -596,13 +584,9 @@ def __init__(
model_path,
precision="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
Expand Down
7 changes: 3 additions & 4 deletions apps/stable_diffusion/web/ui/stablelm_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ def user(message, history):
past_key_values = None

model_map = {
"llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
"llama2_13b": "meta-llama/Llama-2-13b-chat-hf",
"llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
"llama2_7b": "daryl149/llama-2-7b-chat-hf",
"llama2_13b": "daryl149/llama-2-13b-chat-hf",
"llama2_70b": "daryl149/llama-2-70b-chat-hf",
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
}

Expand Down Expand Up @@ -186,7 +186,6 @@ def chat(
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
precision=precision,
max_num_tokens=max_toks,
Expand Down
1 change: 0 additions & 1 deletion shark/iree_utils/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,6 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
def load_vmfb_using_mmap(
flatbuffer_blob_or_path, device: str, device_idx: int = None
):
print(f"Loading module {flatbuffer_blob_or_path}...")
if "rocm" in device:
device = "rocm"
with DetailLogger(timeout=2.5) as dl:
Expand Down

0 comments on commit 1745c8f

Please sign in to comment.