diff --git a/scripts/finetune.py b/scripts/finetune.py
index da08fda0b..f3ce2aea8 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -178,9 +178,8 @@ def train(
     setup_wandb_env_vars(cfg)
 
     # load the tokenizer first
-    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
-    LOG.info(f"loading tokenizer... {tokenizer_config}")
-    tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
+    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
+    tokenizer = load_tokenizer(cfg)
 
     if (
         check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 2f672433d..5ebb0b64c 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -32,32 +32,27 @@ from axolotl.utils.dict import DictDefault  # noqa: F401
 
 
-def load_tokenizer(
-    tokenizer_config,
-    tokenizer_type,
-    cfg,
-):
+def load_tokenizer(cfg):
     tokenizer_kwargs = {}
     use_fast = True  # this is the default
+
     if cfg.tokenizer_use_fast is not None:
         use_fast = cfg.tokenizer_use_fast
 
     if cfg.tokenizer_legacy is not None:
         # True is the default w/ https://github.com/huggingface/transformers/pull/25224
         tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
 
-    if tokenizer_type:
-        tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
-            tokenizer_config,
-            trust_remote_code=cfg.trust_remote_code or False,
-            use_fast=use_fast,
-            **tokenizer_kwargs,
-        )
-    else:
-        tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer_config,
-            trust_remote_code=cfg.trust_remote_code or False,
-            use_fast=use_fast,
-            **tokenizer_kwargs,
-        )
+
+    tokenizer_cls = AutoTokenizer
+    if cfg.tokenizer_type:
+        tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
+
+    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
+    tokenizer = tokenizer_cls.from_pretrained(
+        tokenizer_config,
+        trust_remote_code=cfg.trust_remote_code or False,
+        use_fast=use_fast,
+        **tokenizer_kwargs,
+    )
 
     if tokenizer.__class__.__name__ in [
         "LlamaTokenizer",
diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py
index f2521e8e7..5c8339194 100644
--- a/tests/test_tokenizers.py
+++ b/tests/test_tokenizers.py
@@ -13,17 +13,22 @@ class TestTokenizers(unittest.TestCase):
     """
 
     def test_default_use_fast(self):
-        cfg = DictDefault({})
-        tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+            }
+        )
+        tokenizer = load_tokenizer(cfg)
         assert "Fast" in tokenizer.__class__.__name__
 
     def test_dont_use_fast(self):
         cfg = DictDefault(
             {
+                "tokenizer_config": "huggyllama/llama-7b",
                 "tokenizer_use_fast": False,
             }
         )
-        tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
+        tokenizer = load_tokenizer(cfg)
         assert "Fast" not in tokenizer.__class__.__name__
 
 
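For reference, a minimal caller-side sketch (not part of the patch) of the refactored API; the config values are illustrative, but `DictDefault` and `load_tokenizer` are the axolotl utilities touched above:

    from axolotl.utils.dict import DictDefault
    from axolotl.utils.models import load_tokenizer

    # After this change, everything load_tokenizer needs travels inside one
    # cfg object: tokenizer_config falls back to base_model_config when unset,
    # and tokenizer_type (optional) selects a transformers tokenizer class
    # instead of AutoTokenizer.
    cfg = DictDefault(
        {
            "base_model_config": "huggyllama/llama-7b",
            "tokenizer_use_fast": False,
        }
    )
    tokenizer = load_tokenizer(cfg)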