Commit d45e0d7

…o app_dev

souradipp76 committed Nov 17, 2024
2 parents 503a2be + 26a75f1
Showing 3 changed files with 47 additions and 30 deletions.
readme_ready/VERSION (2 changes: 1 addition & 1 deletion)

@@ -1 +1 @@
-1.1.0
+1.1.1
readme_ready/utils/llm_utils.py (68 changes: 43 additions & 25 deletions)

@@ -17,7 +17,11 @@

 def get_gemma_chat_model(model_name: str, streaming=False, model_kwargs=None):
     """Get GEMMA Chat Model"""
-    bnb_config = None
+    gguf_file = None
+    if "gguf_file" in model_kwargs and model_kwargs["gguf_file"] is not None:
+        gguf_file = model_kwargs["gguf_file"]
+        _ = hf_hub_download(model_name, gguf_file)
+    tokenizer = get_tokenizer(model_name, gguf_file)
     if sys.platform == "linux" or sys.platform == "linux2":
         from transformers import BitsAndBytesConfig

@@ -27,17 +31,22 @@ def get_gemma_chat_model(model_name: str, streaming=False, model_kwargs=None):
             bnb_4bit_quant_type="nf4",
             bnb_4bit_compute_dtype=torch.bfloat16,
         )
-    gguf_file = model_kwargs["gguf_file"]
-    _ = hf_hub_download(model_name, gguf_file)
-    tokenizer = get_tokenizer(model_name, gguf_file)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        gguf_file=gguf_file,
-        trust_remote_code=True,
-        device_map=model_kwargs["device"],
-        quantization_config=bnb_config,
-        token=os.environ["HF_TOKEN"],
-    )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            gguf_file=gguf_file,
+            trust_remote_code=True,
+            device_map=model_kwargs["device"],
+            quantization_config=bnb_config,
+            token=os.environ["HF_TOKEN"],
+        )
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            gguf_file=gguf_file,
+            trust_remote_code=True,
+            device_map=model_kwargs["device"],
+            token=os.environ["HF_TOKEN"],
+        )

     if (
         "peft_model_path" in model_kwargs

@@ -61,7 +70,12 @@

 def get_llama_chat_model(model_name: str, streaming=False, model_kwargs=None):
     """Get LLAMA2 Chat Model"""
-    bnb_config = None
+    gguf_file = None
+    if "gguf_file" in model_kwargs and model_kwargs["gguf_file"] is not None:
+        gguf_file = model_kwargs["gguf_file"]
+        _ = hf_hub_download(model_name, gguf_file)
+    tokenizer = get_tokenizer(model_name, gguf_file)
+    tokenizer.pad_token = tokenizer.eos_token
     if sys.platform == "linux" or sys.platform == "linux2":
         from transformers import BitsAndBytesConfig

@@ -73,18 +87,22 @@ def get_llama_chat_model(model_name: str, streaming=False, model_kwargs=None):
             # use_exllama=False,
             # exllama_config={"version": 2}
         )
-
-    gguf_file = model_kwargs["gguf_file"]
-    _ = hf_hub_download(model_name, gguf_file)
-    tokenizer = get_tokenizer(model_name, gguf_file)
-    tokenizer.pad_token = tokenizer.eos_token
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name,
-        gguf_file=gguf_file,
-        trust_remote_code=True,
-        device_map=model_kwargs["device"],
-        quantization_config=bnb_config,
-    )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            gguf_file=gguf_file,
+            trust_remote_code=True,
+            device_map=model_kwargs["device"],
+            quantization_config=bnb_config,
+            token=os.environ["HF_TOKEN"],
+        )
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            gguf_file=gguf_file,
+            trust_remote_code=True,
+            device_map=model_kwargs["device"],
+            token=os.environ["HF_TOKEN"],
+        )

     if "peft_model" in model_kwargs and model_kwargs["peft_model"] is not None:
         PEFT_MODEL = model_kwargs["peft_model"]
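Taken together, the four hunks above make the same two fixes in both loaders: the GGUF download and tokenizer setup move ahead of the platform check and now tolerate a missing or None "gguf_file" entry, and the 4-bit BitsAndBytes quantization path is taken only on Linux, with a plain from_pretrained call elsewhere (bitsandbytes is generally unavailable off Linux). Below is a minimal consolidated sketch of the resulting control flow; the helper name load_chat_model is hypothetical, the imports and model_kwargs schema are inferred from the diff, and the tokenizer/PEFT handling is omitted:

import os
import sys

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM


def load_chat_model(model_name: str, model_kwargs: dict):
    """Hypothetical consolidation of the gemma/llama loading logic above."""
    # GGUF handling now runs before any platform branch and tolerates
    # a missing or None "gguf_file" entry.
    gguf_file = None
    if "gguf_file" in model_kwargs and model_kwargs["gguf_file"] is not None:
        gguf_file = model_kwargs["gguf_file"]
        _ = hf_hub_download(model_name, gguf_file)

    common = dict(
        gguf_file=gguf_file,
        trust_remote_code=True,
        device_map=model_kwargs["device"],
        token=os.environ["HF_TOKEN"],
    )
    if sys.platform in ("linux", "linux2"):
        # Quantized load: bitsandbytes is only expected to work on Linux.
        from transformers import BitsAndBytesConfig

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        return AutoModelForCausalLM.from_pretrained(
            model_name, quantization_config=bnb_config, **common
        )
    # Non-Linux: load unquantized with otherwise identical kwargs.
    return AutoModelForCausalLM.from_pretrained(model_name, **common)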
tests/utils/test_llm_utils.py (7 changes: 3 additions & 4 deletions)

@@ -69,7 +69,6 @@ def test_get_gemma_chat_model_with_peft():
         gguf_file=model_kwargs["gguf_file"],
         trust_remote_code=True,
         device_map=model_kwargs["device"],
-        quantization_config=mock.ANY,
         token="test_token",
     )
     mock_peft_model.assert_called_once_with(

@@ -129,7 +128,6 @@ def test_get_gemma_chat_model_without_peft():
         gguf_file=model_kwargs["gguf_file"],
         trust_remote_code=True,
         device_map=model_kwargs["device"],
-        quantization_config=mock.ANY,
         token="test_token",
     )
     mock_peft_model.assert_not_called()

@@ -254,7 +252,7 @@ def test_get_llama_chat_model_with_peft():
         gguf_file=model_kwargs["gguf_file"],
         trust_remote_code=True,
         device_map=model_kwargs["device"],
-        quantization_config=mock.ANY,
+        token="test_token",
     )
     mock_peft_model.assert_called_once_with(
         mock_model, model_kwargs["peft_model"]

@@ -316,7 +314,7 @@ def test_get_llama_chat_model_without_peft():
         gguf_file=model_kwargs["gguf_file"],
         trust_remote_code=True,
         device_map=model_kwargs["device"],
-        quantization_config=mock.ANY,
+        token="test_token",
     )
     mock_peft_model.assert_not_called()
     mock_pipeline.assert_called_once()

@@ -377,6 +375,7 @@ def test_get_llama_chat_model_with_bnb_config():
         trust_remote_code=True,
         device_map=model_kwargs["device"],
         quantization_config=mock.ANY,
+        token="test_token",
     )
     mock_peft_model.assert_not_called()
     mock_pipeline.assert_called_once()
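These assertion changes track the loader: on the unquantized path, from_pretrained is no longer expected to receive quantization_config, and the Llama tests now also expect the forwarded HF token. A self-contained sketch of the updated expectation follows; the mock and the values are illustrative stand-ins, not the suite's actual fixtures:

from unittest import mock

mock_from_pretrained = mock.Mock()
model_kwargs = {"gguf_file": "model.gguf", "device": "cpu"}

# Stand-in for the unquantized call the loader makes off Linux:
mock_from_pretrained(
    "test_model",
    gguf_file=model_kwargs["gguf_file"],
    trust_remote_code=True,
    device_map=model_kwargs["device"],
    token="test_token",
)

# The updated expectation: no quantization_config, token now included.
mock_from_pretrained.assert_called_once_with(
    "test_model",
    gguf_file=model_kwargs["gguf_file"],
    trust_remote_code=True,
    device_map=model_kwargs["device"],
    token="test_token",
)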
