Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Replace as_target context managers by direct calls #18325

Merged
merged 11 commits into from
Jul 29, 2022
13 changes: 4 additions & 9 deletions src/transformers/models/m2m_100/tokenization_m2m_100.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Tokenization classes for M2M100."""
import json
import os
from contextlib import contextmanager
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -346,16 +345,12 @@ def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lan
inputs["forced_bos_token_id"] = tgt_lang_id
return inputs

@contextmanager
def as_target_tokenizer(self):
"""
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
sequence-to-sequence models that need a slightly different processing for the labels.
"""
self.set_tgt_lang_special_tokens(self.tgt_lang)
yield
def _switch_to_input_mode(self):
    # Mode hook: restore the source-language special tokens before encoding
    # regular inputs. Replaces the removed `as_target_tokenizer` context manager.
    self.set_src_lang_special_tokens(self.src_lang)

def _switch_to_target_mode(self):
    # Mode hook: apply the target-language special tokens before encoding labels.
    self.set_tgt_lang_special_tokens(self.tgt_lang)

def set_src_lang_special_tokens(self, src_lang: str) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
lang_token = self.get_lang_token(src_lang)
Expand Down
15 changes: 5 additions & 10 deletions src/transformers/models/marian/tokenization_marian.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import os
import re
import warnings
from contextlib import contextmanager
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -281,18 +280,14 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> Lis
# We don't expect to process pairs, but leave the pair logic for API consistency
return token_ids_0 + token_ids_1 + [self.eos_token_id]

@contextmanager
def as_target_tokenizer(self):
"""
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
sequence-to-sequence models that need a slightly different processing for the labels.
"""
def _switch_to_input_mode(self):
    # Mode hook: select the source-side SentencePiece model and the (source)
    # vocabulary encoder for encoding regular inputs.
    self.current_spm = self.spm_source
    self.current_encoder = self.encoder

def _switch_to_target_mode(self):
    """Select the target-side SentencePiece model for encoding labels.

    When the tokenizer uses separate source/target vocabularies, also switch
    the active encoder to the target vocabulary. The stale context-manager
    remnants (`yield` followed by resetting back to the source state) are
    removed: a `yield` in this method would make it a no-op generator, and
    restoring the source state is `_switch_to_input_mode`'s job.
    """
    self.current_spm = self.spm_target
    if self.separate_vocabs:
        self.current_encoder = self.target_encoder

@property
def vocab_size(self) -> int:
Expand Down
15 changes: 5 additions & 10 deletions src/transformers/models/mbart/tokenization_mbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.

import os
from contextlib import contextmanager
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple

Expand Down Expand Up @@ -340,15 +339,11 @@ def prepare_seq2seq_batch(
self.tgt_lang = tgt_lang
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)

@contextmanager
def as_target_tokenizer(self):
"""
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
sequence-to-sequence models that need a slightly different processing for the labels.
"""
self.set_tgt_lang_special_tokens(self.tgt_lang)
yield
self.set_src_lang_special_tokens(self.src_lang)
def _switch_to_input_mode(self):
    # Mode hook: restore the source-language special tokens before encoding
    # regular inputs. (`set_src_lang_special_tokens` returns None; the
    # `return` is just pass-through.)
    return self.set_src_lang_special_tokens(self.src_lang)

def _switch_to_target_mode(self):
    # Mode hook: apply the target-language special tokens before encoding labels.
    return self.set_tgt_lang_special_tokens(self.tgt_lang)

def set_src_lang_special_tokens(self, src_lang) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
Expand Down
15 changes: 5 additions & 10 deletions src/transformers/models/mbart/tokenization_mbart_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.

import os
from contextlib import contextmanager
from shutil import copyfile
from typing import List, Optional, Tuple

Expand Down Expand Up @@ -240,15 +239,11 @@ def prepare_seq2seq_batch(
self.tgt_lang = tgt_lang
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)

@contextmanager
def as_target_tokenizer(self):
"""
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
sequence-to-sequence models that need a slightly different processing for the labels.
"""
self.set_tgt_lang_special_tokens(self.tgt_lang)
yield
self.set_src_lang_special_tokens(self.src_lang)
def _switch_to_input_mode(self):
    # Mode hook: restore the source-language special tokens before encoding
    # regular inputs.
    return self.set_src_lang_special_tokens(self.src_lang)

def _switch_to_target_mode(self):
    # Mode hook: apply the target-language special tokens before encoding labels.
    return self.set_tgt_lang_special_tokens(self.tgt_lang)

def set_src_lang_special_tokens(self, src_lang) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
Expand Down
15 changes: 5 additions & 10 deletions src/transformers/models/mbart50/tokenization_mbart50.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.

import os
from contextlib import contextmanager
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple

Expand Down Expand Up @@ -337,15 +336,11 @@ def prepare_seq2seq_batch(
self.tgt_lang = tgt_lang
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)

@contextmanager
def as_target_tokenizer(self):
"""
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
sequence-to-sequence models that need a slightly different processing for the labels.
"""
self.set_tgt_lang_special_tokens(self.tgt_lang)
yield
self.set_src_lang_special_tokens(self.src_lang)
def _switch_to_input_mode(self):
    # Mode hook: restore the source-language special tokens before encoding
    # regular inputs.
    return self.set_src_lang_special_tokens(self.src_lang)

def _switch_to_target_mode(self):
    # Mode hook: apply the target-language special tokens before encoding labels.
    return self.set_tgt_lang_special_tokens(self.tgt_lang)

def set_src_lang_special_tokens(self, src_lang: str) -> None:
"""Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos]."""
Expand Down
15 changes: 5 additions & 10 deletions src/transformers/models/mbart50/tokenization_mbart50_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.

import os
from contextlib import contextmanager
from shutil import copyfile
from typing import List, Optional, Tuple

Expand Down Expand Up @@ -211,15 +210,11 @@ def prepare_seq2seq_batch(
self.tgt_lang = tgt_lang
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)

@contextmanager
def as_target_tokenizer(self):
"""
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
sequence-to-sequence models that need a slightly different processing for the labels.
"""
self.set_tgt_lang_special_tokens(self.tgt_lang)
yield
self.set_src_lang_special_tokens(self.src_lang)
def _switch_to_input_mode(self):
    # Mode hook: restore the source-language special tokens before encoding
    # regular inputs.
    return self.set_src_lang_special_tokens(self.src_lang)

def _switch_to_target_mode(self):
    # Mode hook: apply the target-language special tokens before encoding labels.
    return self.set_tgt_lang_special_tokens(self.tgt_lang)

def set_src_lang_special_tokens(self, src_lang: str) -> None:
"""Reset the special tokens to the source lang setting. prefix=[src_lang_code] and suffix=[eos]."""
Expand Down
62 changes: 60 additions & 2 deletions src/transformers/models/mctct/processing_mctct.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""
Speech processor class for M-CTC-T
"""
import warnings
from contextlib import contextmanager

from ...processing_utils import ProcessorMixin
Expand All @@ -39,6 +40,7 @@ class MCTCTProcessor(ProcessorMixin):
def __init__(self, feature_extractor, tokenizer):
    """Wrap a feature extractor and a tokenizer into a single processor."""
    super().__init__(feature_extractor, tokenizer)
    # Default processing mode is audio: forward calls to the feature extractor.
    self.current_processor = self.feature_extractor
    # True only while inside the deprecated `as_target_processor` context.
    self._in_target_context_manager = False

def __call__(self, *args, **kwargs):
    """
    When used in normal mode, this method forwards all its arguments to MCTCTFeatureExtractor's
    [`~MCTCTFeatureExtractor.__call__`] and prepares the audio input(s). If used in the context
    [`~MCTCTProcessor.as_target_processor`] this method forwards all its arguments to AutoTokenizer's
    [`~AutoTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
    """
    # For backward compatibility: inside the deprecated `as_target_processor`
    # context, defer entirely to whichever processor is currently active.
    if self._in_target_context_manager:
        return self.current_processor(*args, **kwargs)

    # `raw_speech` is the legacy keyword; map it onto `audio` with a warning.
    if "raw_speech" in kwargs:
        warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
        audio = kwargs.pop("raw_speech")
    else:
        audio = kwargs.pop("audio", None)
    text = kwargs.pop("text", None)
    # A leading positional argument is treated as the audio input.
    if len(args) > 0:
        audio = args[0]
        args = args[1:]

    if audio is None and text is None:
        raise ValueError("You need to specify either an `audio` or `text` input to process.")

    if audio is not None:
        inputs = self.feature_extractor(audio, *args, **kwargs)
    if text is not None:
        encodings = self.tokenizer(text, **kwargs)

    # Return features, encodings, or features with the token ids as labels.
    if text is None:
        return inputs
    elif audio is None:
        return encodings
    else:
        inputs["labels"] = encodings["input_ids"]
        return inputs

def batch_decode(self, *args, **kwargs):
"""
def pad(self, *args, **kwargs):
    """
    When used in normal mode, this method forwards all its arguments to MCTCTFeatureExtractor's
    [`~MCTCTFeatureExtractor.pad`] and pads the audio features. If used in the context
    [`~MCTCTProcessor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
    [`~PreTrainedTokenizer.pad`]. Please refer to the docstring of the above two methods for more information.
    """
    # For backward compatibility: inside the deprecated `as_target_processor`
    # context, defer entirely to whichever processor is currently active.
    if self._in_target_context_manager:
        return self.current_processor.pad(*args, **kwargs)

    input_features = kwargs.pop("input_features", None)
    labels = kwargs.pop("labels", None)
    # A leading positional argument is treated as the audio features.
    if len(args) > 0:
        input_features = args[0]
        args = args[1:]

    if input_features is not None:
        input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
    if labels is not None:
        labels = self.tokenizer.pad(labels, **kwargs)

    # Return padded features, padded labels, or features with padded label ids.
    if labels is None:
        return input_features
    elif input_features is None:
        return labels
    else:
        input_features["labels"] = labels["input_ids"]
        return input_features

def decode(self, *args, **kwargs):
"""
@contextmanager
def as_target_processor(self):
    """
    Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning MCTCT.

    Deprecated: pass `text=` to the regular `__call__` instead.
    """
    warnings.warn(
        "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
        "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
        "your audio inputs, or in a separate call."
    )
    # Flag consumed by `__call__`/`pad` to take the legacy forwarding path.
    self._in_target_context_manager = True
    self.current_processor = self.tokenizer
    yield
    # Restore normal (audio) mode on exit.
    self.current_processor = self.feature_extractor
    self._in_target_context_manager = False
15 changes: 5 additions & 10 deletions src/transformers/models/nllb/tokenization_nllb.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.

import os
from contextlib import contextmanager
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple

Expand Down Expand Up @@ -386,15 +385,11 @@ def prepare_seq2seq_batch(
self.tgt_lang = tgt_lang
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)

@contextmanager
def as_target_tokenizer(self):
"""
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
sequence-to-sequence models that need a slightly different processing for the labels.
"""
self.set_tgt_lang_special_tokens(self.tgt_lang)
yield
self.set_src_lang_special_tokens(self.src_lang)
def _switch_to_input_mode(self):
    # Mode hook: restore the source-language special tokens before encoding
    # regular inputs.
    return self.set_src_lang_special_tokens(self.src_lang)

def _switch_to_target_mode(self):
    # Mode hook: apply the target-language special tokens before encoding labels.
    return self.set_tgt_lang_special_tokens(self.tgt_lang)

def set_src_lang_special_tokens(self, src_lang) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
Expand Down
15 changes: 5 additions & 10 deletions src/transformers/models/nllb/tokenization_nllb_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.

import os
from contextlib import contextmanager
from shutil import copyfile
from typing import List, Optional, Tuple

Expand Down Expand Up @@ -284,15 +283,11 @@ def prepare_seq2seq_batch(
self.tgt_lang = tgt_lang
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)

@contextmanager
def as_target_tokenizer(self):
"""
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
sequence-to-sequence models that need a slightly different processing for the labels.
"""
self.set_tgt_lang_special_tokens(self.tgt_lang)
yield
self.set_src_lang_special_tokens(self.src_lang)
def _switch_to_input_mode(self):
    # Mode hook: restore the source-language special tokens before encoding
    # regular inputs.
    return self.set_src_lang_special_tokens(self.src_lang)

def _switch_to_target_mode(self):
    # Mode hook: apply the target-language special tokens before encoding labels.
    return self.set_tgt_lang_special_tokens(self.tgt_lang)

def set_src_lang_special_tokens(self, src_lang) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
Expand Down
15 changes: 5 additions & 10 deletions src/transformers/models/plbart/tokenization_plbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.

import os
from contextlib import contextmanager
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple

Expand Down Expand Up @@ -441,15 +440,11 @@ def prepare_seq2seq_batch(
self.tgt_lang = tgt_lang
return super().prepare_seq2seq_batch(src_texts, tgt_texts, **kwargs)

@contextmanager
def as_target_tokenizer(self):
"""
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
sequence-to-sequence models that need a slightly different processing for the labels.
"""
self.set_tgt_lang_special_tokens(self.tgt_lang)
yield
self.set_src_lang_special_tokens(self.src_lang)
def _switch_to_input_mode(self):
    # Mode hook: restore the source-language special tokens before encoding
    # regular inputs.
    return self.set_src_lang_special_tokens(self.src_lang)

def _switch_to_target_mode(self):
    # Mode hook: apply the target-language special tokens before encoding labels.
    return self.set_tgt_lang_special_tokens(self.tgt_lang)

def set_src_lang_special_tokens(self, src_lang) -> None:
"""Reset the special tokens to the source lang setting. No prefix and suffix=[eos, src_lang_code]."""
Expand Down
13 changes: 4 additions & 9 deletions src/transformers/models/rag/tokenization_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Tokenization classes for RAG."""
import os
import warnings
from contextlib import contextmanager
from typing import List, Optional

from ...tokenization_utils_base import BatchEncoding
Expand Down Expand Up @@ -68,16 +67,12 @@ def batch_decode(self, *args, **kwargs):
def decode(self, *args, **kwargs):
return self.generator.decode(*args, **kwargs)

@contextmanager
def as_target_tokenizer(self):
"""
Temporarily sets the tokenizer for encoding the targets. Useful for tokenizer associated to
sequence-to-sequence models that need a slightly different processing for the labels.
"""
self.current_tokenizer = self.generator
yield
def _switch_to_input_mode(self):
    # Mode hook: use the question-encoder tokenizer for regular (input) encoding.
    self.current_tokenizer = self.question_encoder

def _switch_to_target_mode(self):
    # Mode hook: use the generator tokenizer for encoding labels.
    self.current_tokenizer = self.generator

def prepare_seq2seq_batch(
self,
src_texts: List[str],
Expand Down
Loading