fix relative path for fixtures
winglian authored and NanoCode012 committed May 30, 2023
1 parent a1f9850 commit cfcc549
Showing 3 changed files with 18 additions and 5 deletions.
5 changes: 3 additions & 2 deletions src/axolotl/utils/models.py
@@ -129,6 +129,7 @@ def load_model(
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch_dtype,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
@@ -280,8 +281,8 @@ def load_model(
# llama is PROBABLY model parallelizable, but the default isn't that it is
# so let's only set it for the 4bit, see
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
setattr(model, 'is_parallelizable', True)
setattr(model, 'model_parallel', True)
setattr(model, "is_parallelizable", True)
setattr(model, "model_parallel", True)

requires_grad = []
for name, param in model.named_parameters(recurse=True):
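For readers less familiar with the quantization block touched above: those keyword arguments belong to a BitsAndBytesConfig that is handed to transformers when the base model is loaded. A minimal, self-contained sketch of that pattern follows; the model id and compute dtype are placeholders for illustration, not values taken from this commit.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Sketch only: 4-bit NF4 loading with the same flags as in the diff above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.float16,  # stands in for torch_dtype
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",  # placeholder model id
    quantization_config=bnb_config,
    device_map="auto",
)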
3 changes: 2 additions & 1 deletion src/axolotl/utils/trainer.py
@@ -125,7 +125,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
output_dir=cfg.output_dir,
save_total_limit=3,
load_best_model_at_end=(
cfg.val_set_size > 0
cfg.load_best_model_at_end is not False
and cfg.val_set_size > 0
and save_steps
and save_steps % eval_steps == 0
and cfg.load_in_8bit is not True
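The new cfg.load_best_model_at_end is not False clause lets a user explicitly opt out while keeping the old default behavior. A rough standalone illustration of the full guard follows; the cfg attribute names come from the diff, but wrapping them in a helper like this is an assumption, not the actual axolotl code.

def should_load_best_model(cfg, save_steps, eval_steps):
    # Mirrors the condition in setup_trainer above; hypothetical helper for illustration.
    return bool(
        cfg.load_best_model_at_end is not False  # explicit opt-out wins
        and cfg.val_set_size > 0                 # an eval split must exist
        and save_steps                           # periodic checkpoints are enabled
        and save_steps % eval_steps == 0         # every save lines up with an eval
        and cfg.load_in_8bit is not True         # skipped for 8-bit loads
    )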
15 changes: 13 additions & 2 deletions tests/test_prompt_tokenizers.py
@@ -1,6 +1,8 @@
"""Module for testing prompt tokenizers."""
import json
import logging
import unittest

from pathlib import Path

from transformers import AutoTokenizer
@@ -12,6 +14,10 @@


class TestPromptTokenizationStrategies(unittest.TestCase):
"""
Test class for prompt tokenization strategies.
"""

def setUp(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
@@ -24,10 +30,15 @@ def setUp(self) -> None:

def test_sharegpt_integration(self):
print(Path(__file__).parent)
with open(Path(__file__).parent / "fixtures/conversation.json", "r") as fin:
with open(
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
) as fin:
data = fin.read()
conversation = json.loads(data)
with open(Path(__file__).parent / "fixtures/conversation.tokenized.json", "r") as fin:
with open(
Path(__file__).parent / "fixtures/conversation.tokenized.json",
encoding="utf-8",
) as fin:
data = fin.read()
tokenized_conversation = json.loads(data)
prompter = ShareGPTPrompter("chat")
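The fixture change is the actual subject of the commit: paths are resolved relative to the test module via Path(__file__).parent, so the tests no longer depend on the current working directory, and the files are opened with an explicit encoding. A small sketch of that pattern, assuming a fixtures directory next to the test file; the helper name is invented for illustration.

import json
from pathlib import Path

FIXTURES_DIR = Path(__file__).parent / "fixtures"

def load_fixture(name):
    # Resolve relative to this file, not the cwd; explicit encoding for portability.
    with open(FIXTURES_DIR / name, encoding="utf-8") as fin:
        return json.loads(fin.read())

# e.g. conversation = load_fixture("conversation.json")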
