simplify load_tokenizer #375

Merged: 1 commit, Aug 13, 2023

scripts/finetune.py: 2 additions, 3 deletions

@@ -178,9 +178,8 @@ def train(
     setup_wandb_env_vars(cfg)

     # load the tokenizer first
-    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
-    LOG.info(f"loading tokenizer... {tokenizer_config}")
-    tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
+    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")

Review comment (author): maybe normalize_config should set cfg.tokenizer_config so we don't need this "or" fallback here?
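
A rough sketch of what that suggestion might look like (hypothetical, not part of this PR; it assumes normalize_config receives the same cfg object used here):

    def normalize_config(cfg):
        # ...existing normalization...
        # Resolve the tokenizer config once so later code can read
        # cfg.tokenizer_config directly instead of repeating the fallback.
        cfg.tokenizer_config = cfg.tokenizer_config or cfg.base_model_config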

+    tokenizer = load_tokenizer(cfg)

     if (
         check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference

src/axolotl/utils/models.py: 14 additions, 19 deletions

@@ -32,32 +32,27 @@
 from axolotl.utils.dict import DictDefault  # noqa: F401


-def load_tokenizer(
-    tokenizer_config,
-    tokenizer_type,
-    cfg,
-):
+def load_tokenizer(cfg):
     tokenizer_kwargs = {}
     use_fast = True  # this is the default

     if cfg.tokenizer_use_fast is not None:
         use_fast = cfg.tokenizer_use_fast
     if cfg.tokenizer_legacy is not None:
         # True is the default w/ https://github.com/huggingface/transformers/pull/25224
         tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
-    if tokenizer_type:
-        tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
-            tokenizer_config,
-            trust_remote_code=cfg.trust_remote_code or False,
-            use_fast=use_fast,
-            **tokenizer_kwargs,
-        )
-    else:
-        tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer_config,
-            trust_remote_code=cfg.trust_remote_code or False,
-            use_fast=use_fast,
-            **tokenizer_kwargs,
-        )
+    tokenizer_cls = AutoTokenizer
+    if cfg.tokenizer_type:
+        tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
+
+    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
+    tokenizer = tokenizer_cls.from_pretrained(
+        tokenizer_config,
+        trust_remote_code=cfg.trust_remote_code or False,
+        use_fast=use_fast,
+        **tokenizer_kwargs,
+    )

     if tokenizer.__class__.__name__ in [
         "LlamaTokenizer",

tests/test_tokenizers.py: 8 additions, 3 deletions

@@ -13,17 +13,22 @@ class TestTokenizers(unittest.TestCase):
"""

def test_default_use_fast(self):
cfg = DictDefault({})
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
}
)
tokenizer = load_tokenizer(cfg)
assert "Fast" in tokenizer.__class__.__name__

def test_dont_use_fast(self):
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"tokenizer_use_fast": False,
}
)
tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
tokenizer = load_tokenizer(cfg)
assert "Fast" not in tokenizer.__class__.__name__

