Added ngram tokenizer (#2723)
Co-authored-by: Geoffrey Angus <geoffrey@predibase.com>
tgaddair and geoffreyangus authored Nov 4, 2022
1 parent e85930d commit 30caa73
Showing 8 changed files with 84 additions and 2 deletions.
1 change: 1 addition & 0 deletions ludwig/features/sequence_feature.py
@@ -205,6 +205,7 @@ def get_feature_meta(column, preprocessing_parameters, backend):
vocab_file=preprocessing_parameters["vocab_file"],
unknown_symbol=preprocessing_parameters["unknown_symbol"],
padding_symbol=preprocessing_parameters["padding_symbol"],
ngram_size=preprocessing_parameters["ngram_size"],
processor=backend.df_engine,
)
max_length = min(preprocessing_parameters["max_sequence_length"], max_length)
1 change: 1 addition & 0 deletions ludwig/features/text_feature.py
@@ -80,6 +80,7 @@ def feature_meta(column, preprocessing_parameters, backend):
unknown_symbol=preprocessing_parameters["unknown_symbol"],
padding_symbol=preprocessing_parameters["padding_symbol"],
pretrained_model_name_or_path=preprocessing_parameters["pretrained_model_name_or_path"],
ngram_size=preprocessing_parameters["ngram_size"],
processor=backend.df_engine,
)
return (
7 changes: 7 additions & 0 deletions ludwig/schema/features/preprocessing/sequence.py
@@ -96,6 +96,13 @@ class SequencePreprocessingConfig(BasePreprocessingConfig):
parameter_metadata=FEATURE_METADATA[SEQUENCE][PREPROCESSING]["computed_fill_value"],
)

ngram_size: int = schema_utils.PositiveInteger(
default=2,
allow_none=False,
description="The size of the n-gram when using the `ngram` tokenizer (e.g., 2 = bigram, 3 = trigram, etc.).",
parameter_metadata=FEATURE_METADATA[SEQUENCE][PREPROCESSING]["ngram_size"],
)


@register_preprocessor("sequence_output")
@dataclass(repr=False)
7 changes: 7 additions & 0 deletions ludwig/schema/features/preprocessing/text.py
@@ -106,6 +106,13 @@ class TextPreprocessingConfig(BasePreprocessingConfig):
parameter_metadata=FEATURE_METADATA[TEXT][PREPROCESSING]["computed_fill_value"],
)

ngram_size: int = schema_utils.PositiveInteger(
default=2,
allow_none=False,
description="The size of the n-gram when using the `ngram` tokenizer (e.g., 2 = bigram, 3 = trigram, etc.).",
parameter_metadata=FEATURE_METADATA[TEXT][PREPROCESSING]["ngram_size"],
)


@register_preprocessor("text_output")
@dataclass(repr=False)
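Together, the `ngram_size` fields added to the sequence and text schemas are exposed through each feature's `preprocessing` block. A minimal sketch of a config using them, written as a Python dict (feature names are hypothetical; "ngram" is the tokenizer name registered in ludwig/utils/tokenizers.py below):

config = {
    "input_features": [
        {
            "name": "review",  # hypothetical feature name
            "type": "text",
            "preprocessing": {
                "tokenizer": "ngram",  # registry name for NgramTokenizer
                "ngram_size": 3,  # unigrams, bigrams, and trigrams
            },
        }
    ],
    "output_features": [{"name": "label", "type": "category"}],  # hypothetical
}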
28 changes: 28 additions & 0 deletions ludwig/schema/metadata/feature_metadata.py
@@ -1110,6 +1110,20 @@
literature_references=None,
internal_only=False,
),
"ngram_size": ParameterMetadata(
ui_display_name="n-gram size",
default_value_reasoning="Size of the n-gram when using the `ngram` tokenizer.",
example_value=3,
related_parameters=None,
other_information=None,
description_implications=None,
suggested_values=None,
suggested_values_reasoning=None,
commonly_used=False,
expected_impact=ExpectedImpact.UNKNOWN,
literature_references=None,
internal_only=False,
),
}
},
"set": {
@@ -1440,6 +1454,20 @@
literature_references=None,
internal_only=False,
),
"ngram_size": ParameterMetadata(
ui_display_name="n-gram size",
default_value_reasoning="Size of the n-gram when using the `ngram` tokenizer.",
example_value=3,
related_parameters=None,
other_information=None,
description_implications=None,
suggested_values=None,
suggested_values_reasoning=None,
commonly_used=False,
expected_impact=ExpectedImpact.UNKNOWN,
literature_references=None,
internal_only=False,
),
}
},
"timeseries": {
3 changes: 3 additions & 0 deletions ludwig/utils/strings_utils.py
@@ -208,6 +208,7 @@ def create_vocabulary(
start_symbol: str = START_SYMBOL,
stop_symbol: str = STOP_SYMBOL,
pretrained_model_name_or_path: str = None,
ngram_size: Optional[int] = None,
processor: DataFrameEngine = PANDAS,
):
"""Computes a vocabulary over the provided data frame.
@@ -236,6 +237,7 @@
start_symbol: String representation for the START symbol.
stop_symbol: String representation for the STOP symbol.
pretrained_model_name_or_path: Name/path to huggingface model.
ngram_size: Size of the n-gram when using `ngram` tokenizer.
processor: Which processor to use to process data.
Returns:
@@ -256,6 +258,7 @@
tokenizer = get_tokenizer_from_registry(tokenizer_type)(
vocab_file=vocab_file,
pretrained_model_name_or_path=pretrained_model_name_or_path,
ngram_size=ngram_size,
)

# Pre-trained huggingface tokenizer. Use the pre-existing vocabulary and special symbols.
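The generic constructor call above passes the same keyword set to every registered tokenizer, so `ngram_size` reaches NgramTokenizer while tokenizers that don't use it simply ignore it. A minimal sketch of that construction path (assuming `get_tokenizer_from_registry` is importable from ludwig.utils.tokenizers, as the call above implies):

from ludwig.utils.tokenizers import get_tokenizer_from_registry

tokenizer_cls = get_tokenizer_from_registry("ngram")
# NgramTokenizer consumes ngram_size; the remaining kwargs fall into **kwargs and are ignored.
tokenizer = tokenizer_cls(vocab_file=None, pretrained_model_name_or_path=None, ngram_size=3)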
19 changes: 18 additions & 1 deletion ludwig/utils/tokenizers.py
@@ -80,13 +80,29 @@ def forward(self, v: Union[str, List[str], torch.Tensor]) -> Any:
for sequence in inputs:
split_sequence = sequence.strip().split(" ")
token_sequence: List[str] = []
- for token in split_sequence:
+ for token in self.get_tokens(split_sequence):
if len(token) > 0:
token_sequence.append(token)
tokens.append(token_sequence)

return tokens[0] if isinstance(v, str) else tokens

def get_tokens(self, tokens: List[str]) -> List[str]:
return tokens


class NgramTokenizer(SpaceStringToListTokenizer):
"""Implements torchscript-compatible n-gram tokenization."""

def __init__(self, ngram_size: int = 2, **kwargs):
super().__init__()
self.n = ngram_size or 2

def get_tokens(self, tokens: List[str]) -> List[str]:
from torchtext.data.utils import ngrams_iterator

return list(ngrams_iterator(tokens, ngrams=self.n))


class SpacePunctuationStringToListTokenizer(torch.nn.Module):
"""Implements torchscript-compatible space_punct tokenization."""
@@ -802,6 +818,7 @@ def get_unk_token(self) -> str:
# Torchscript-compatible tokenizers. Torchtext tokenizers are also available below (requires torchtext>=0.12.0).
"space": SpaceStringToListTokenizer,
"space_punct": SpacePunctuationStringToListTokenizer,
"ngram": NgramTokenizer,
# Tokenizers not compatible with torchscript
"characters": CharactersToListTokenizer,
"underscore": UnderscoreStringToListTokenizer,
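A quick behavior sketch for the new `ngram` entry (illustrative input; torchtext's ngrams_iterator yields the unigrams first, then the space-joined bigrams):

from ludwig.utils.tokenizers import NgramTokenizer

tokenizer = NgramTokenizer(ngram_size=2)
print(tokenizer("new york city"))
# ['new', 'york', 'city', 'new york', 'york city']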
20 changes: 19 additions & 1 deletion tests/ludwig/utils/test_tokenizers.py
@@ -3,7 +3,7 @@
import torchtext
from transformers.models.bert.tokenization_bert import PRETRAINED_INIT_CONFIGURATION, PRETRAINED_VOCAB_FILES_MAP

- from ludwig.utils.tokenizers import SKIP_TORCHTEXT_BERT_HF_MODEL_NAMES
+ from ludwig.utils.tokenizers import NgramTokenizer, SKIP_TORCHTEXT_BERT_HF_MODEL_NAMES


@pytest.mark.parametrize(
@@ -41,3 +41,21 @@ def test_bert_hf_tokenizer_parity(pretrained_model_name_or_path):
assert not isinstance(tokenizer_ids_only, HFTokenizer)
assert tokens == tokens_expected
assert token_ids == token_ids_expected


def test_ngram_tokenizer():
inputs = "Hello, I'm a single sentence!"
tokenizer = NgramTokenizer(ngram_size=2)
tokens_expected = [
"Hello,",
"I'm",
"a",
"single",
"sentence!",
"Hello, I'm",
"I'm a",
"a single",
"single sentence!",
]
tokens = tokenizer(inputs)
assert tokens == tokens_expected
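A follow-on sketch (not part of this commit) checking that ngram_size=3 also yields trigrams:

def test_ngram_tokenizer_trigrams():
    tokenizer = NgramTokenizer(ngram_size=3)
    tokens = tokenizer("Hello, I'm a single sentence!")
    # the same unigrams and bigrams as above, plus the three trigrams
    assert "Hello, I'm a" in tokens
    assert "I'm a single" in tokens
    assert "a single sentence!" in tokens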
