Commit fa5d2d4

[AutoMM] Support customizing use_fast for AutoTokenizer (open-mmlab#3379)
1 parent 03cc58c commit fa5d2d4

File tree

5 files changed: +82 −1


multimodal/src/autogluon/multimodal/configs/model/fusion_mlp_image_text_tabular.yaml

+1
@@ -62,6 +62,7 @@ model:
     data_types:
     - "text"
     tokenizer_name: "hf_auto"
+    use_fast: True # Use a fast Rust-based tokenizer if it is supported for a given model. If a fast tokenizer is not available for a given model, a normal Python-based tokenizer is returned instead.
     max_text_len: 512 # If None or <=0, then use the max length of pretrained models.
     insert_sep: True
     low_cpu_mem_usage: False
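The new key maps to the predictor hyperparameter model.hf_text.use_fast, which the new test at the end of this commit exercises. A minimal usage sketch, assuming your own label column and training DataFrame; only the two model.hf_text.* keys come from this change:

from autogluon.multimodal import MultiModalPredictor

# "label" and train_df are placeholders for your own data, not part of this commit.
predictor = MultiModalPredictor(label="label")
predictor.fit(
    train_data=train_df,
    hyperparameters={
        "model.hf_text.checkpoint_name": "albert-base-v2",
        "model.hf_text.use_fast": False,  # opt out of the Rust-based fast tokenizer
    },
)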

multimodal/src/autogluon/multimodal/data/process_text.py

+13 −1
@@ -95,6 +95,7 @@ def __init__(
         train_augment_types: Optional[List[str]] = None,
         template_config: Optional[DictConfig] = None,
         normalize_text: Optional[bool] = False,
+        use_fast: Optional[bool] = True,
     ):
         """
         Parameters
@@ -125,6 +126,11 @@ def __init__(
             Whether to normalize text to resolve encoding problems.
             Examples of normalized texts can be found at
             https://github.com/autogluon/autogluon/tree/master/examples/automm/kaggle_feedback_prize#15-a-few-examples-of-normalized-texts
+        use_fast
+            Use a fast Rust-based tokenizer if it is supported for a given model.
+            If a fast tokenizer is not available for a given model, a normal Python-based tokenizer is returned instead.
+            See: https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoTokenizer.from_pretrained.use_fast
+
         """
         self.prefix = model.prefix
         self.tokenizer_name = tokenizer_name
@@ -136,6 +142,7 @@ def __init__(
         self.tokenizer = self.get_pretrained_tokenizer(
             tokenizer_name=tokenizer_name,
             checkpoint_name=model.checkpoint_name,
+            use_fast=use_fast,
         )
         if hasattr(self.tokenizer, "deprecation_warnings"):
             # Disable the warning "Token indices sequence length is longer than the specified maximum sequence..."
@@ -410,6 +417,7 @@ def get_special_tokens(tokenizer):
 def get_pretrained_tokenizer(
     tokenizer_name: str,
     checkpoint_name: str,
+    use_fast: Optional[bool] = True,
 ):
     """
     Load the tokenizer for a pre-trained huggingface checkpoint.
@@ -420,14 +428,18 @@ def get_pretrained_tokenizer(
         The tokenizer type, e.g., "bert", "clip", "electra", and "hf_auto".
     checkpoint_name
         Name of a pre-trained checkpoint.
+    use_fast
+        Use a fast Rust-based tokenizer if it is supported for a given model.
+        If a fast tokenizer is not available for a given model, a normal Python-based tokenizer is returned instead.
+        See: https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoTokenizer.from_pretrained.use_fast

     Returns
     -------
     A tokenizer instance.
     """
     try:
         tokenizer_class = ALL_TOKENIZERS[tokenizer_name]
-        return tokenizer_class.from_pretrained(checkpoint_name)
+        return tokenizer_class.from_pretrained(checkpoint_name, use_fast=use_fast)
     except TypeError as e:
         try:
             tokenizer_class = ALL_TOKENIZERS["bert"]
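For the "hf_auto" tokenizer type this flag ends up in Hugging Face's AutoTokenizer.from_pretrained, whose use_fast behavior the docstring links to. A minimal sketch of that behavior outside AutoGluon, assuming the transformers package and a checkpoint that ships a fast tokenizer:

from transformers import AutoTokenizer

fast_tok = AutoTokenizer.from_pretrained("albert-base-v2", use_fast=True)
slow_tok = AutoTokenizer.from_pretrained("albert-base-v2", use_fast=False)

# A Rust-backed tokenizer is returned when one exists for the checkpoint;
# otherwise transformers falls back to the Python implementation.
print(type(fast_tok).__name__)  # AlbertTokenizerFast
print(type(slow_tok).__name__)  # AlbertTokenizer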

multimodal/src/autogluon/multimodal/utils/data.py

+1
@@ -157,6 +157,7 @@ def create_data_processor(
             train_augment_types=OmegaConf.select(model_config, "text_train_augment_types"),
             template_config=getattr(config.data, "templates", OmegaConf.create({"turn_on": False})),
             normalize_text=getattr(config.data.text, "normalize_text", False),
+            use_fast=OmegaConf.select(model_config, "use_fast", default=True),
         )
     elif data_type == CATEGORICAL:
         data_processor = CategoricalProcessor(
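The OmegaConf.select call keeps older configs working: when a model config has no use_fast key, the processor still receives True. A minimal sketch of that fallback with toy configs (not from the codebase):

from omegaconf import OmegaConf

old_config = OmegaConf.create({"checkpoint_name": "albert-base-v2"})
new_config = OmegaConf.create({"checkpoint_name": "albert-base-v2", "use_fast": False})

print(OmegaConf.select(old_config, "use_fast", default=True))  # True (key absent)
print(OmegaConf.select(new_config, "use_fast", default=True))  # False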

multimodal/tests/hf_model_list.yaml

+1
@@ -40,6 +40,7 @@ others_2:
   - t5-small
   - microsoft/layoutlmv3-base
   - microsoft/layoutlmv2-base-uncased
+  - albert-base-v2
 predictor:
   - CLTL/MedRoBERTa.nl
   - google/electra-small-discriminator
@@ -0,0 +1,66 @@
+import os
+import shutil
+import tempfile
+
+import pytest
+from transformers import AlbertTokenizer, AlbertTokenizerFast
+
+from autogluon.multimodal import MultiModalPredictor
+from autogluon.multimodal.constants import TEXT
+
+from ..utils.unittest_datasets import AEDataset, HatefulMeMesDataset, IDChangeDetectionDataset, PetFinderDataset
+
+ALL_DATASETS = {
+    "petfinder": PetFinderDataset,
+    "hateful_memes": HatefulMeMesDataset,
+    "ae": AEDataset,
+}
+
+
+@pytest.mark.parametrize(
+    "checkpoint_name,use_fast,tokenizer_type",
+    [
+        (
+            "albert-base-v2",
+            None,
+            AlbertTokenizerFast,
+        ),
+        (
+            "albert-base-v2",
+            True,
+            AlbertTokenizerFast,
+        ),
+        (
+            "albert-base-v2",
+            False,
+            AlbertTokenizer,
+        ),
+    ],
+)
+def test_tokenizer_use_fast(checkpoint_name, use_fast, tokenizer_type):
+    dataset = ALL_DATASETS["ae"]()
+    metric_name = dataset.metric
+
+    predictor = MultiModalPredictor(
+        label=dataset.label_columns[0],
+        problem_type=dataset.problem_type,
+        eval_metric=metric_name,
+    )
+    hyperparameters = {
+        "data.categorical.convert_to_text": True,
+        "data.numerical.convert_to_text": True,
+        "model.hf_text.checkpoint_name": checkpoint_name,
+    }
+    if use_fast is not None:
+        hyperparameters["model.hf_text.use_fast"] = use_fast
+
+    with tempfile.TemporaryDirectory() as save_path:
+        if os.path.isdir(save_path):
+            shutil.rmtree(save_path)
+        predictor.fit(
+            train_data=dataset.train_df,
+            time_limit=5,
+            save_path=save_path,
+            hyperparameters=hyperparameters,
+        )
+        assert isinstance(predictor._data_processors[TEXT][0].tokenizer, tokenizer_type)
