Updated characters, underscore and comma preprocessors to be TorchScriptable. #3602

Merged: 7 commits, Sep 14, 2023
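For context, "TorchScriptable" here means the tokenizer modules can be compiled with torch.jit.script so they can run inside an exported inference graph. Below is a minimal sketch of what the merged change enables, assuming the module builds as shown in the diff that follows (illustrative only; the output comments are not taken from this PR):

import torch

from ludwig.utils.tokenizers import CommaStringToListTokenizer

tokenizer = CommaStringToListTokenizer()
scripted = torch.jit.script(tokenizer)  # compiles once the tokenizer is a torch.nn.Module with a scriptable forward()

print(tokenizer("a,b,,c"))  # ["a", "b", "c"] -- empty tokens are dropped
print(scripted("a,b,,c"))   # same result from the scripted module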
76 changes: 54 additions & 22 deletions ludwig/utils/tokenizers.py
@@ -14,7 +14,6 @@
"""

import logging
import re
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Union

@@ -29,11 +28,7 @@
logger = logging.getLogger(__name__)
torchtext_version = torch.torch_version.TorchVersion(torchtext.__version__)

SPACE_PUNCTUATION_REGEX = re.compile(r"\w+|[^\w\s]")
COMMA_REGEX = re.compile(r"\s*,\s*")
UNDERSCORE_REGEX = re.compile(r"\s*_\s*")

TORCHSCRIPT_COMPATIBLE_TOKENIZERS = {"space", "space_punct"}
TORCHSCRIPT_COMPATIBLE_TOKENIZERS = {"space", "space_punct", "comma", "underscore", "characters"}
TORCHTEXT_0_12_0_TOKENIZERS = {"sentencepiece", "clip", "gpt2bpe"}
TORCHTEXT_0_13_0_TOKENIZERS = {"bert"}

@@ -50,14 +45,61 @@ def __call__(self, text: str):
        pass


class CharactersToListTokenizer(BaseTokenizer):
    def __call__(self, text):
        return [char for char in text]
class StringSplitTokenizer(torch.nn.Module):
    def __init__(self, split_string, **kwargs):
        super().__init__()
        self.split_string = split_string

    def forward(self, v: Union[str, List[str], torch.Tensor]) -> Any:
        if isinstance(v, torch.Tensor):
            raise ValueError(f"Unsupported input: {v}")

        inputs: List[str] = []

Review comment (Contributor):

I see that you are adapting an existing implementation, though this seems more complicated than I would expect (for example, why do we have a get_tokens() function that returns its own input?).

@geoffreyangus, out of curiosity, does this also look strange to you, or is this imposed on us by torchscript?

geoffreyangus (Contributor) replied on Sep 14, 2023:

It looks like NgramTokenizer, which subclasses SpaceStringToListTokenizer (which in turn subclasses the new StringSplitTokenizer), overrides get_tokens: https://github.com/ludwig-ai/ludwig/pull/3602/files#diff-5cbace55f4f4fd07725c061b9f981b83fe43cb53b0045cf1257c9fb5d4931f0dR132-R142

        # Ludwig calls map on List[str] objects, so we need to handle individual strings as well.
        if isinstance(v, str):
            inputs.append(v)
        else:
            inputs.extend(v)

        tokens: List[List[str]] = []
        for sequence in inputs:
            split_sequence = sequence.strip().split(self.split_string)
            token_sequence: List[str] = []
            for token in self.get_tokens(split_sequence):
                if len(token) > 0:
                    token_sequence.append(token)
            tokens.append(token_sequence)

        return tokens[0] if isinstance(v, str) else tokens

class SpaceStringToListTokenizer(torch.nn.Module):
    def get_tokens(self, tokens: List[str]) -> List[str]:
        return tokens


class SpaceStringToListTokenizer(StringSplitTokenizer):
    """Implements torchscript-compatible whitespace tokenization."""

    def __init__(self, **kwargs):
        super().__init__(split_string=" ", **kwargs)


class UnderscoreStringToListTokenizer(StringSplitTokenizer):
    """Implements torchscript-compatible underscore tokenization."""

    def __init__(self, **kwargs):
        super().__init__(split_string="_", **kwargs)


class CommaStringToListTokenizer(StringSplitTokenizer):
    """Implements torchscript-compatible comma tokenization."""

    def __init__(self, **kwargs):
        super().__init__(split_string=",", **kwargs)


class CharactersToListTokenizer(torch.nn.Module):
    """Implements torchscript-compatible characters tokenization."""

    def __init__(self, **kwargs):
        super().__init__()

@@ -74,7 +116,7 @@ def forward(self, v: Union[str, List[str], torch.Tensor]) -> Any:

        tokens: List[List[str]] = []
        for sequence in inputs:
            split_sequence = sequence.strip().split(" ")
            split_sequence = [char for char in sequence]
            token_sequence: List[str] = []
            for token in self.get_tokens(split_sequence):
                if len(token) > 0:
@@ -142,16 +184,6 @@ def forward(self, v: Union[str, List[str], torch.Tensor]) -> Any:
        return tokens[0] if isinstance(v, str) else tokens


class UnderscoreStringToListTokenizer(BaseTokenizer):
    def __call__(self, text):
        return UNDERSCORE_REGEX.split(text.strip())


class CommaStringToListTokenizer(BaseTokenizer):
    def __call__(self, text):
        return COMMA_REGEX.split(text.strip())


class UntokenizedStringToListTokenizer(BaseTokenizer):
    def __call__(self, text):
        return [text]
@@ -855,10 +887,10 @@ def _set_pad_token(self) -> None:
"space": SpaceStringToListTokenizer,
"space_punct": SpacePunctuationStringToListTokenizer,
"ngram": NgramTokenizer,
# Tokenizers not compatible with torchscript
"characters": CharactersToListTokenizer,
"underscore": UnderscoreStringToListTokenizer,
"comma": CommaStringToListTokenizer,
# Tokenizers not compatible with torchscript
"untokenized": UntokenizedStringToListTokenizer,
"stripped": StrippedStringToListTokenizer,
"english_tokenize": EnglishTokenizer,
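On the review question above about why StringSplitTokenizer.get_tokens() returns its input unchanged: it is a hook that subclasses can override to post-process the split tokens, which is how NgramTokenizer (linked in the thread) reuses the same forward() loop. A rough, hypothetical illustration of such an override, not the actual NgramTokenizer code:

from typing import List

from ludwig.utils.tokenizers import SpaceStringToListTokenizer


class BigramSpaceTokenizer(SpaceStringToListTokenizer):
    """Hypothetical subclass: splits on spaces, then pairs adjacent tokens via get_tokens()."""

    def get_tokens(self, tokens: List[str]) -> List[str]:
        # ["new", "york", "city"] -> ["new york", "york city"]
        bigrams: List[str] = []
        for i in range(len(tokens) - 1):
            bigrams.append(tokens[i] + " " + tokens[i + 1])
        return bigrams

The base class handles string-vs-list inputs and empty-token filtering, so a subclass only has to describe how to transform the token list.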
9 changes: 8 additions & 1 deletion tests/ludwig/utils/test_tokenizers.py
@@ -4,7 +4,7 @@
import torch
import torchtext

from ludwig.utils.tokenizers import EnglishLemmatizeFilterTokenizer, NgramTokenizer
from ludwig.utils.tokenizers import EnglishLemmatizeFilterTokenizer, NgramTokenizer, StringSplitTokenizer

TORCHTEXT_0_14_0_HF_NAMES = [
"bert-base-uncased",
@@ -73,6 +73,13 @@ def test_ngram_tokenizer():
    assert tokens == tokens_expected


def test_string_split_tokenizer():
inputs = "Multiple,Elements,Are here!"
tokenizer = StringSplitTokenizer(",")
tokens = tokenizer(inputs)
assert tokens == ["Multiple", "Elements", "Are here!"]


def test_english_lemmatize_filter_tokenizer():
inputs = "Hello, I'm a single sentence!"
tokenizer = EnglishLemmatizeFilterTokenizer()
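A possible follow-up to the new test, sketched here rather than taken from the PR: round-trip each newly compatible tokenizer through torch.jit.script and check that it matches eager execution. The test name and inputs below are made up for illustration:

import pytest
import torch

from ludwig.utils.tokenizers import (
    CharactersToListTokenizer,
    CommaStringToListTokenizer,
    UnderscoreStringToListTokenizer,
)


@pytest.mark.parametrize(
    "tokenizer_cls,text",
    [
        (CommaStringToListTokenizer, "a,b, c"),
        (UnderscoreStringToListTokenizer, "a_b_ c"),
        (CharactersToListTokenizer, "abc"),
    ],
)
def test_torchscript_matches_eager(tokenizer_cls, text):
    tokenizer = tokenizer_cls()
    scripted = torch.jit.script(tokenizer)
    # The scripted module should tokenize exactly like the eager one.
    assert scripted(text) == tokenizer(text)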