From e00e47fbc8ea5bb13f4acf9674d2ec0979416542 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Wed, 11 Nov 2020 22:52:57 -0500 Subject: [PATCH 1/2] TAPAS --- src/transformers/tokenization_tapas.py | 1356 ++++++++++++++++-------- tests/test_tokenization_tapas.py | 1195 +++++++++++++++++++++ 2 files changed, 2100 insertions(+), 451 deletions(-) create mode 100644 tests/test_tokenization_tapas.py diff --git a/src/transformers/tokenization_tapas.py b/src/transformers/tokenization_tapas.py index 33cffb0d32c4..2484f9b38b48 100644 --- a/src/transformers/tokenization_tapas.py +++ b/src/transformers/tokenization_tapas.py @@ -24,27 +24,21 @@ import os import re import unicodedata +import warnings from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Text, Tuple, Union +from typing import Callable, Dict, Generator, List, Optional, Text, Tuple, Union +import pandas as pd import torch from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace from .tokenization_utils_base import ( - ENCODE_KWARGS_DOCSTRING, - ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING, - INIT_TOKENIZER_DOCSTRING, - AddedToken, BatchEncoding, EncodedInput, - EncodedInputPair, PaddingStrategy, PreTokenizedInput, - PreTokenizedInputPair, - PreTrainedTokenizerBase, TensorType, TextInput, - TextInputPair, TruncationStrategy, ) from .utils import logging @@ -71,6 +65,9 @@ } +TableValue = collections.namedtuple("TokenValue", ["token", "column_id", "row_id"]) + + @dataclass(frozen=True) class TokenCoordinates: column_index: int @@ -235,6 +232,12 @@ def __init__( mask_token=mask_token, tokenize_chinese_chars=tokenize_chinese_chars, strip_accents=strip_accents, + cell_trim_length=cell_trim_length, + max_column_id=max_column_id, + max_row_id=max_row_id, + strip_column_names=strip_column_names, + update_answer_coordinates=update_answer_coordinates, + drop_rows_to_fit=drop_rows_to_fit, **kwargs, ) @@ -323,6 +326,829 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = index += 1 return (vocab_file,) + def create_attention_mask_from_sequences(self, query_ids: List[int], table_values: List[TableValue]) -> List[int]: + table_ids = list(zip(*table_values))[0] if table_values else [] + return [1] * (1 + len(query_ids) + 1) + [0] * len(table_ids) + + def create_segment_token_type_ids_from_sequences( + self, query_ids: List[int], table_values: List[TableValue] + ) -> List[int]: + table_ids = list(zip(*table_values))[0] if table_values else [] + return [0] * (1 + len(query_ids) + 1) + [1] * len(table_ids) + + def create_column_token_type_ids_from_sequences( + self, query_ids: List[int], table_values: List[TableValue] + ) -> List[int]: + table_column_ids = list(zip(*table_values))[1] if table_values else [] + return [0] * (1 + len(query_ids) + 1) + list(table_column_ids) + + def create_row_token_type_ids_from_sequences( + self, query_ids: List[int], table_values: List[TableValue] + ) -> List[int]: + table_row_ids = list(zip(*table_values))[2] if table_values else [] + return [0] * (1 + len(query_ids) + 1) + list(table_row_ids) + + def create_label_ids_from_sequences_and_answers( + self, query_ids: List[int], table_values: List[TableValue] + ) -> List[int]: + table_row_ids = list(zip(*table_values))[2] if table_values else [] + return [0] * (1 + len(query_ids) + 1) + list(table_row_ids) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a 
pair of sequence for sequence classification tasks by concatenating and + adding special tokens. + + This implementation does not add special tokens and this method should be overridden in a subclass. + + Args: + token_ids_0 (:obj:`List[int]`): The first tokenized sequence. + token_ids_1 (:obj:`List[int]`, `optional`): The second tokenized sequence. + + Returns: + :obj:`List[int]`: The model input with special tokens. + """ + if token_ids_1 is None: + raise ValueError("With TAPAS, you must provide both question IDs and table IDs.") + + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + token_ids_1 + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` method. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formatted with special tokens for the model." + ) + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + return [1] + ([0] * len(token_ids_0)) + [1] + + def __call__( + self, + table: pd.DataFrame, + queries: Optional[ + Union[ + List[TextInput], + List[PreTokenizedInput], + List[EncodedInput], + ] + ] = None, + answer_coordinates: Optional[List[Tuple]] = None, + answer_texts: Optional[List[TextInput]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + """ + Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of + sequences. + + Args: + text (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + :obj:`is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + text_pair (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). 
If the sequences are provided as list of strings (pretokenized), you must set + :obj:`is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + """ + assert isinstance(table, pd.DataFrame), "Table must be of type pd.DataFrame" + + # Input type checking for clearer error + assert ( + queries is None + or isinstance(queries, str) + or ( + isinstance(queries, (list, tuple)) + and ( + len(queries) == 0 + or ( + isinstance(queries[0], str) + or ( + isinstance(queries[0], (list, tuple)) + and (len(queries[0]) == 0 or isinstance(queries[0][0], str)) + ) + ) + ) + ) + ), ( + "queries input must of type `str` (single example), `List[str]` (batch or single pretokenized example) " + "or `List[List[str]]` (batch of pretokenized examples)." + ) + + is_batched = isinstance(queries, (list, tuple)) + + if is_batched: + return self.batch_encode_plus( + table=table, + queries=queries, + answer_coordinates=answer_coordinates, + answer_texts=answer_texts, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + else: + return self.encode_plus( + table=table, + query=queries, + answer_coordinate=answer_coordinates, + answer_text=answer_texts, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def batch_encode_plus( + self, + table: pd.DataFrame, + queries: Optional[ + Union[ + List[TextInput], + List[PreTokenizedInput], + List[EncodedInput], + ] + ] = None, + answer_coordinates: Optional[List[Tuple]] = None, + answer_texts: Optional[List[TextInput]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + if return_token_type_ids is not None and not add_special_tokens: + raise ValueError( + "Asking to return token_type_ids while setting 
add_special_tokens to False " + "results in an undefined behavior. Please set add_special_tokens to True or " + "set return_token_type_ids to None." + ) + + if (answer_coordinates and not answer_texts) or (not answer_coordinates and answer_texts): + raise ValueError("In case you provide answers, both answer_coordinate and answer_text should be provided") + elif answer_coordinates is None and answer_texts is None: + answer_coordinates = answer_texts = [None] * len(queries) + + if "is_split_into_words" in kwargs: + raise NotImplementedError("Currently TapasTokenizer only supports questions as strings.") + + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers." + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." + ) + + if "return_lengths" in kwargs: + if verbose: + warnings.warn( + "The PreTrainedTokenizerBase.prepare_for_model `return_lengths` parameter is deprecated. " + "Please use `return_length` instead.", + FutureWarning, + ) + return_length = kwargs["return_lengths"] + + return self._batch_encode_plus( + table=table, + queries=queries, + answer_coordinates=answer_coordinates, + answer_texts=answer_texts, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _batch_encode_plus( + self, + table, + queries: Union[ + List[TextInput], + List[PreTokenizedInput], + List[EncodedInput], + ], + answer_coordinates: Optional[List[Tuple]] = None, + answer_texts: Optional[List[TextInput]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = True, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + + table_tokens = self._tokenize_table(table) + + queries_tokens = [] + queries_ids = [] + for query in queries: + query_tokens = self.tokenize(query) + queries_tokens.append(query_tokens) + queries_ids.append(self.convert_tokens_to_ids(query_tokens)) + + num_rows = self._get_num_rows(table, self.drop_rows_to_fit) + num_columns = self._get_num_columns(table) + + _, _, num_tokens = self._get_table_boundaries(table_tokens) + + table_data = list(self._get_table_values(table_tokens, num_columns, num_rows, num_tokens)) + + table_ids = list(zip(*table_data))[0] if len(table_data) > 0 else list(zip(*table_data)) + table_ids = self.convert_tokens_to_ids(list(table_ids)) + + batch_outputs = self._batch_prepare_for_model( + table_ids, + queries_ids, + table, + 
queries, + table_data=table_data, + queries_tokens=queries_tokens, + answer_coordinates=answer_coordinates, + answer_texts=answer_texts, + add_special_tokens=add_special_tokens, + padding=padding_strategy.value, + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + prepend_batch_axis=True, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + verbose=verbose, + ) + + return BatchEncoding(batch_outputs) + + def _batch_prepare_for_model( + self, + table_ids: List[int], + queries_ids: List[List[int]], + raw_table: pd.DataFrame, + raw_queries: Union[ + TextInput, + PreTokenizedInput, + EncodedInput, + ], + answer_coordinates: Optional[List[Tuple]] = None, + answer_texts: Optional[List[TextInput]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = True, + return_attention_mask: Optional[bool] = True, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + prepend_batch_axis: bool = False, + **kwargs + ) -> BatchEncoding: + """ + Prepares a sequence of strings (queries) related to a table so that it can be used by the model. 
It creates + input ids, adds special tokens, truncates the table if overflowing (if the drop_rows_to_fit parameter is set to + True) while taking into account the special tokens and manages a moving window (with user defined stride) for + overflowing tokens + + This function is based on prepare_for_model (but in Tapas, training examples depend on each other, so we + defined it at a batch level) + + Args: + table: Pandas dataframe + queries: List of Strings, containing questions related to the table + """ + batch_outputs = {} + + if "table_data" in kwargs and "queries_tokens" in kwargs: + table_data = kwargs["table_data"] + queries_tokens = kwargs["queries_tokens"] + else: + table_data = None + queries_tokens = [None] * len(queries_ids) + + for query_ids, raw_query, query_tokens, answer_coordinate, answer_text in zip( + queries_ids, raw_queries, queries_tokens, answer_coordinates, answer_texts + ): + outputs = self.prepare_for_model( + table_ids, + query_ids, + raw_table, + raw_query, + table_data=table_data, + query_tokens=query_tokens, + answer_coordinate=answer_coordinate, + answer_text=answer_text, + add_special_tokens=add_special_tokens, + padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=None, # we pad in batch afterward + return_attention_mask=False, # we pad in batch afterward + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=None, # We convert the whole batch to tensors at the end + prepend_batch_axis=False, + verbose=verbose, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + batch_outputs = self.pad( + batch_outputs, + padding=padding, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) + + return batch_outputs + + def encode( + self, + table: pd.DataFrame, + query: Optional[ + Union[ + TextInput, + PreTokenizedInput, + EncodedInput, + ] + ] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + stride: int = 0, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs + ) -> List[int]: + encoded_inputs = self.encode_plus( + table, + query=query, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + return_tensors=return_tensors, + **kwargs, + ) + + return encoded_inputs["input_ids"] + + def encode_plus( + self, + table: pd.DataFrame, + query: Optional[ + Union[ + TextInput, + PreTokenizedInput, + EncodedInput, + ] + ] = None, + answer_coordinate: Optional[List[Tuple]] = None, + answer_text: Optional[List[TextInput]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + 
return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ) -> BatchEncoding: + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + if return_token_type_ids is not None and not add_special_tokens: + raise ValueError( + "Asking to return token_type_ids while setting add_special_tokens to False " + "results in an undefined behavior. Please set add_special_tokens to True or " + "set return_token_type_ids to None." + ) + + if (answer_coordinate and not answer_text) or (not answer_coordinate and answer_text): + raise ValueError("In case you provide answers, both answer_coordinate and answer_text should be provided") + + if "is_split_into_words" in kwargs: + raise NotImplementedError("Currently TapasTokenizer only supports questions as strings.") + + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers." + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." + ) + + if "return_lengths" in kwargs: + if verbose: + warnings.warn( + "The PreTrainedTokenizerBase.prepare_for_model `return_lengths` parameter is deprecated. " + "Please use `return_length` instead.", + FutureWarning, + ) + return_length = kwargs["return_lengths"] + + return self._encode_plus( + table=table, + query=query, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _encode_plus( + self, + table: pd.DataFrame, + query: Union[ + TextInput, + PreTokenizedInput, + EncodedInput, + ], + answer_coordinate: Optional[List[Tuple]] = None, + answer_text: Optional[List[TextInput]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = True, + return_attention_mask: Optional[bool] = True, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs + ): + if query is None: + query = "" + logger.warning( + "TAPAS is a question answering model but you have not passed a query. Please be aware that the " + "model will probably not behave correctly." 
+ ) + + table_tokens = self._tokenize_table(table) + query_tokens = self.tokenize(query) + + num_rows = self._get_num_rows(table, self.drop_rows_to_fit) + num_columns = self._get_num_columns(table) + + _, _, num_tokens = self._get_table_boundaries(table_tokens) + + table_data = list(self._get_table_values(table_tokens, num_columns, num_rows, num_tokens)) + + query_ids = self.convert_tokens_to_ids(query_tokens) + table_ids = list(zip(*table_data))[0] if len(table_data) > 0 else list(zip(*table_data)) + table_ids = self.convert_tokens_to_ids(list(table_ids)) + + return self.prepare_for_model( + table_ids, + query_ids, + table, + query, + table_data=table_data, + query_tokens=query_tokens, + answer_coordinate=answer_coordinate, + answer_text=answer_text, + add_special_tokens=add_special_tokens, + padding=padding_strategy.value, + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + prepend_batch_axis=True, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + verbose=verbose, + ) + + def prepare_for_model( + self, + table_ids: List[int], + query_ids: List[int], + raw_table: pd.DataFrame, + raw_query: Union[ + TextInput, + PreTokenizedInput, + EncodedInput, + ], + answer_coordinate: Optional[List[Tuple]] = None, + answer_text: Optional[List[TextInput]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = True, + return_attention_mask: Optional[bool] = True, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + prepend_batch_axis: bool = False, + **kwargs + ) -> BatchEncoding: + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + encoded_inputs = {} + + # This can be retrieved from the encoding step, which prevents recomputing. + # We still need to handle recomputing as `prepare_for_model` should be callable on raw IDs/table/query as well. 
+ if ( + "table_data" not in kwargs + or "query_tokens" not in kwargs + or ( + ("table_data" in kwargs and kwargs["table_data"] is None) + and ("query_tokens" in kwargs and kwargs["query_tokens"] is None) + ) + ): + table_tokens = self._tokenize_table(raw_table) + num_rows = self._get_num_rows(raw_table, self.drop_rows_to_fit) + num_columns = self._get_num_columns(raw_table) + _, _, num_tokens = self._get_table_boundaries(table_tokens) + table_data = list(self._get_table_values(table_tokens, num_columns, num_rows, num_tokens)) + query_tokens = self.tokenize(raw_query) + else: + table_data = kwargs["table_data"] + query_tokens = kwargs["query_tokens"] + + total_len = ( + len(query_ids) + len(table_ids) + (self.num_special_tokens_to_add(pair=True) if add_special_tokens else 0) + ) + + overflowing_tokens = [] + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: + query_ids, table_ids, overflowing_tokens = self.truncate_sequences( + query_ids, + pair_ids=table_ids, + num_tokens_to_remove=total_len - max_length, + truncation_strategy=truncation_strategy, + stride=stride, + ) + + if return_overflowing_tokens: + encoded_inputs["overflowing_tokens"] = overflowing_tokens + encoded_inputs["num_truncated_tokens"] = total_len - max_length + + if add_special_tokens: + input_ids = self.build_inputs_with_special_tokens(query_ids, table_ids) + else: + input_ids = query_ids + table_ids + + encoded_inputs["input_ids"] = input_ids + + segment_ids = self.create_segment_token_type_ids_from_sequences(query_ids, table_data) + column_ids = self.create_column_token_type_ids_from_sequences(query_ids, table_data) + row_ids = self.create_row_token_type_ids_from_sequences(query_ids, table_data) + prev_label_ids = [0] * len(row_ids) + + column_ranks, inv_column_ranks, columns_to_numeric_values = self._get_numeric_column_ranks( + column_ids, row_ids, raw_table + ) + numeric_relations = self._get_numeric_relations( + raw_query, column_ids, row_ids, raw_table, columns_to_numeric_values + ) + + # Load from model defaults + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + if return_attention_mask: + attention_mask = self.create_attention_mask_from_sequences(query_ids, table_data) + encoded_inputs["attention_mask"] = attention_mask + + if answer_coordinate is not None and answer_text is not None: + label_ids = self.get_answer_ids( + column_ids, row_ids, table_data, query_tokens, answer_text, answer_coordinate + ) + numeric_values = self._get_numeric_values(raw_table, column_ids, row_ids, columns_to_numeric_values) + numeric_values_scale = self._get_numeric_values_scale(raw_table, column_ids, row_ids) + + encoded_inputs["label_ids"] = label_ids + encoded_inputs["numeric_values"] = numeric_values + encoded_inputs["numeric_values_scale"] = numeric_values_scale + + if return_token_type_ids: + token_type_ids = [ + segment_ids, + column_ids, + row_ids, + prev_label_ids, + column_ranks, + inv_column_ranks, + numeric_relations, + ] + + token_type_ids = [list(ids) for ids in list(zip(*token_type_ids))] + encoded_inputs["token_type_ids"] = token_type_ids + + if return_special_tokens_mask: + if add_special_tokens: + encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(query_ids, table_ids) + else: + encoded_inputs["special_tokens_mask"] = [0] * len(input_ids) + + # Check lengths + if max_length is None and 
len(encoded_inputs["input_ids"]) > self.model_max_length and verbose: + if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False): + logger.warning( + "Token indices sequence length is longer than the specified maximum sequence length " + "for this model ({} > {}). Running this sequence through the model will result in " + "indexing errors".format(len(encoded_inputs["input_ids"]), self.model_max_length) + ) + self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True + + # Padding + if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: + encoded_inputs = self.pad( + encoded_inputs, + max_length=max_length, + padding=padding_strategy.value, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + if return_length: + encoded_inputs["length"] = len(encoded_inputs["input_ids"]) + + batch_outputs = BatchEncoding( + encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis + ) + + return batch_outputs + def _tokenize_table( self, table=None, @@ -384,7 +1210,7 @@ def _get_token_budget(self, question_tokens): """ return self.model_max_length - self._question_encoding_cost(question_tokens) - def _get_table_values(self, table, num_columns, num_rows, num_tokens): + def _get_table_values(self, table, num_columns, num_rows, num_tokens) -> Generator[TableValue, None, None]: """Iterates over partial table and returns token, column and row indexes.""" for tc in table.selected_tokens: # First row is header row. @@ -401,7 +1227,7 @@ def _get_table_values(self, table, num_columns, num_rows, num_tokens): word_begin_index -= 1 if word_begin_index >= num_tokens: continue - yield token, tc.column_index + 1, tc.row_index + yield TableValue(token, tc.column_index + 1, tc.row_index) def _get_table_boundaries(self, table): """Return maximal number of rows, columns and tokens.""" @@ -528,7 +1354,7 @@ def _get_cell_token_indexes(self, column_ids, row_ids, column_id, row_id): if column_ids[index] - 1 == column_id and row_ids[index] - 1 == row_id: yield index - def _add_numeric_column_ranks(self, column_ids, row_ids, table, features): + def _get_numeric_column_ranks(self, column_ids, row_ids, table): """Adds column ranks for all numeric columns.""" ranks = [0] * len(column_ids) @@ -565,10 +1391,7 @@ def _add_numeric_column_ranks(self, column_ids, row_ids, table, features): ranks[index] = rank + 1 inv_ranks[index] = len(unique_values) - rank - features["column_ranks"] = ranks - features["inv_column_ranks"] = inv_ranks - - return features, columns_to_numeric_values + return ranks, inv_ranks, columns_to_numeric_values def _get_numeric_sort_key_fn(self, table_numeric_values, value): """ @@ -591,16 +1414,15 @@ def _get_numeric_sort_key_fn(self, table_numeric_values, value): except ValueError: return None - def _add_numeric_relations(self, question, column_ids, row_ids, table, features, columns_to_numeric_values): + def _get_numeric_relations(self, question, column_ids, row_ids, table, columns_to_numeric_values): """ - Adds numeric relation embeddings to 'features' + Return numeric relations embeddings Args: question: The question, numeric values are used. column_ids: Maps word piece position to column id. row_ids: Maps word piece position to row id. table: The table containing the numeric cell values. - features: Output. columns_to_numeric_values: Dictionary that maps column indices to numeric values. 
""" @@ -633,12 +1455,10 @@ def _add_numeric_relations(self, question, column_ids, row_ids, table, features, for cell_token_index in self._get_cell_token_indexes(column_ids, row_ids, column_index, row_index): numeric_relations[cell_token_index] = relation_set_index - features["numeric_relations"] = numeric_relations - - return features + return numeric_relations - def _add_numeric_values(self, table, token_ids_dict, features, columns_to_numeric_values): - """Adds numeric values for computation of answer loss.""" + def _get_numeric_values(self, table, column_ids, row_ids, columns_to_numeric_values): + """Returns numeric values for computation of answer loss.""" numeric_values = [float("nan")] * self.model_max_length @@ -659,17 +1479,13 @@ def _add_numeric_values(self, table, token_ids_dict, features, columns_to_numeri if float_value == float("inf"): continue - for index in self._get_cell_token_indexes( - token_ids_dict["column_ids"], token_ids_dict["row_ids"], col_index, row_index - ): + for index in self._get_cell_token_indexes(column_ids, row_ids, col_index, row_index): numeric_values[index] = float_value - features["numeric_values"] = numeric_values - - return features + return numeric_values - def _add_numeric_values_scale(self, table, token_ids_dict, features): - """Adds a scale to each token to down weigh the value of long words.""" + def _get_numeric_values_scale(self, table, column_ids, row_ids): + """Returns a scale to each token to down weigh the value of long words.""" numeric_values_scale = [1.0] * self.model_max_length @@ -681,20 +1497,13 @@ def _add_numeric_values_scale(self, table, token_ids_dict, features): for col_index in range(num_columns): for row_index in range(num_rows): - indices = [ - index - for index in self._get_cell_token_indexes( - token_ids_dict["column_ids"], token_ids_dict["row_ids"], col_index, row_index - ) - ] + indices = [index for index in self._get_cell_token_indexes(column_ids, row_ids, col_index, row_index)] num_indices = len(indices) if num_indices > 1: for index in indices: numeric_values_scale[index] = float(num_indices) - features["numeric_values_scale"] = numeric_values_scale - - return features + return numeric_values_scale def _pad_to_seq_length(self, inputs): while len(inputs) > self.model_max_length: @@ -702,110 +1511,6 @@ def _pad_to_seq_length(self, inputs): while len(inputs) < self.model_max_length: inputs.append(0) - def _to_features(self, tokens, token_ids_dict, table, question): - """ - Produces a dict of features. This function creates input ids, attention mask, token type ids (except the prev - label ids), as well as numeric value and numeric value scale. 
- """ - tokens = list(tokens) - token_ids_dict = {key: list(values) for key, values in token_ids_dict.items()} - - length = len(tokens) - for values in token_ids_dict.values(): - if len(values) != length: - raise ValueError("Inconsistent length") - - # currently the input ids, mask and token type ids are created here - # also, padding and truncation up to max length is done here (see function _pad_to_seq_length) - input_ids = self.convert_tokens_to_ids(tokens) - attention_mask = [1] * len(input_ids) - - self._pad_to_seq_length(input_ids) - self._pad_to_seq_length(attention_mask) - for values in token_ids_dict.values(): - self._pad_to_seq_length(values) - - assert len(input_ids) == self.model_max_length - assert len(attention_mask) == self.model_max_length - for values in token_ids_dict.values(): - assert len(values) == self.model_max_length - - features = {} - features["input_ids"] = input_ids - features["attention_mask"] = attention_mask - for key, values in sorted(token_ids_dict.items()): - features[key] = values - - features, columns_to_numeric_values = self._add_numeric_column_ranks( - token_ids_dict["column_ids"], token_ids_dict["row_ids"], table, features - ) - - features = self._add_numeric_relations( - question, - token_ids_dict["column_ids"], - token_ids_dict["row_ids"], - table, - features, - columns_to_numeric_values, - ) - - # finally, add numeric values and numeric values scale (only needed in case of regression loss calculation) - # so they should only be returned in case answer_coordinates + answer_texts are provided - - features = self._add_numeric_values(table, token_ids_dict, features, columns_to_numeric_values) - - features = self._add_numeric_values_scale(table, token_ids_dict, features) - - # we do not add table id and table id hash (was used in the original implementation) - # if table: - # features['table_id'] = create_string_feature([table.table_id.encode('utf8')]) - # features['table_id_hash'] = create_int_feature([fingerprint(table.table_id) % _MAX_INT]) - - return features - - def _to_trimmed_features( - self, - question, - table, - question_tokens, - tokenized_table, - num_columns, - num_rows, - drop_rows_to_fit=False, - ): - """Finds optimal number of table tokens to include and serializes.""" - init_num_rows = num_rows - while True: - num_tokens = self._get_max_num_tokens( - question_tokens, - tokenized_table, - num_rows=num_rows, - num_columns=num_columns, - ) - if num_tokens is not None: - # We could fit the table. - break - if not drop_rows_to_fit or num_rows == 0: - raise ValueError("Sequence too long") - # Try to drop a row to fit the table. - num_rows -= 1 - - serialized_example = self._serialize(question_tokens, tokenized_table, num_columns, num_rows, num_tokens) - - assert len(serialized_example.tokens) <= self.model_max_length - - feature_dict = { - "column_ids": serialized_example.column_ids, - "row_ids": serialized_example.row_ids, - "segment_ids": serialized_example.segment_ids, - } - - features = self._to_features(serialized_example.tokens, feature_dict, table=table, question=question) - - return serialized_example, features - - #### Everything related to label ids calculation #### - def _get_all_answer_ids_from_coordinates( self, column_ids, @@ -832,11 +1537,11 @@ def _get_all_answer_ids(self, column_ids, row_ids, question, answer_coordinates) the TSV format, the coordinates are given as (row, column) tuples. Here, we swap them to (column, row) format. 
""" - def _to_coordinates(question, answer_coordinates_question): + def _to_coordinates(answer_coordinates_question): return [(coords[1], coords[0]) for coords in answer_coordinates_question] return self._get_all_answer_ids_from_coordinates( - column_ids, row_ids, answers_list=(_to_coordinates(question, answer_coordinates)) + column_ids, row_ids, answers_list=(_to_coordinates(answer_coordinates)) ) def _find_tokens(self, text, segment): @@ -928,329 +1633,78 @@ def get_answer_ids( ) return self._get_answer_ids(column_ids, row_ids, question, answer_coordinates_question) - #### End of everything related to label ids calculation #### - - def batch_encode_plus( - self, - table, - queries: Union[ - List[TextInput], - List[PreTokenizedInput], - List[EncodedInput], - ], - answer_coordinates: Optional[List[Tuple]] = None, - answer_texts: Optional[List[TextInput]] = None, - add_special_tokens: bool = True, - padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, - truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, - max_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: bool = False, - pad_to_multiple_of: Optional[int] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - return_token_type_ids: Optional[bool] = True, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - **kwargs - ) -> BatchEncoding: - """ - Tokenize and prepare for the model a list of one or more sequences related to a table. .. warning:: This method - is deprecated, ``__call__`` should be used instead - - Args: - queries (:obj:`List[str]`): - Batch of sequences (queries) related to a table to be encoded. This is a list of string-sequences (see - details in ``encode_plus``). 
- """ - - # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' - # padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( - # padding=padding, - # truncation=truncation, - # max_length=max_length, - # pad_to_multiple_of=pad_to_multiple_of, - # verbose=verbose, - # **kwargs, - # ) - - return self._batch_encode_plus( - table=table, - queries=queries, - answer_coordinates=answer_coordinates, - answer_texts=answer_texts, - add_special_tokens=add_special_tokens, - padding_strategy=padding_strategy, - truncation_strategy=truncation_strategy, - max_length=max_length, - stride=stride, - is_split_into_words=is_split_into_words, - pad_to_multiple_of=pad_to_multiple_of, - return_tensors=return_tensors, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_length=return_length, - verbose=verbose, - **kwargs, - ) - - def _batch_encode_plus( + def _pad( self, - table, - queries: Union[ - List[TextInput], - List[PreTokenizedInput], - List[EncodedInput], - ], - answer_coordinates: Optional[List[Tuple]] = None, - answer_texts: Optional[List[TextInput]] = None, - add_special_tokens: bool = True, - padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, - truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], max_length: Optional[int] = None, - stride: int = 0, - is_split_into_words: bool = False, - pad_to_multiple_of: Optional[int] = None, - return_tensors: Optional[Union[str, TensorType]] = None, - return_token_type_ids: Optional[bool] = True, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - **kwargs - ) -> BatchEncoding: - - if return_offsets_mapping: - raise NotImplementedError( - "return_offset_mapping is not available when using Python tokenizers." - "To use this feature, change your tokenizer to one deriving from " - "transformers.PreTrainedTokenizerFast." 
- ) - - if "is_pretokenized" in kwargs: - warnings.warn( - "`is_pretokenized` is deprecated and will be removed in a future version, use `is_split_into_words` instead.", - FutureWarning, - ) - - if "is_split_into_words" in kwargs: - raise NotImplementedError("Currently TapasTokenizer only supports questions as strings.") - - batch_outputs = self._batch_prepare_for_model( - table=table, - queries=queries, - answer_coordinates=answer_coordinates, - answer_texts=answer_texts, - add_special_tokens=add_special_tokens, - padding_strategy=padding_strategy, - truncation_strategy=truncation_strategy, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_token_type_ids=return_token_type_ids, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_length=return_length, - return_tensors=return_tensors, - verbose=verbose, - ) - - return BatchEncoding(batch_outputs) - - def _batch_prepare_for_model( - self, - table, - queries: Union[ - List[TextInput], - List[PreTokenizedInput], - List[EncodedInput], - ], - answer_coordinates: Optional[List[Tuple]] = None, - answer_texts: Optional[List[TextInput]] = None, - add_special_tokens: bool = True, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, - truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, - max_length: Optional[int] = None, - stride: int = 0, pad_to_multiple_of: Optional[int] = None, - return_tensors: Optional[str] = None, - return_token_type_ids: Optional[bool] = True, return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_length: bool = False, - verbose: bool = True, - **kwargs - ) -> BatchEncoding: + ) -> dict: """ - Prepares a sequence of strings (queries) related to a table so that it can be used by the model. It creates - input ids, adds special tokens, truncates the table if overflowing (if the drop_rows_to_fit parameter is set to - True) while taking into account the special tokens and manages a moving window (with user defined stride) for - overflowing tokens - - This function is based on prepare_for_model (but in Tapas, training examples depend on each other, so we - defined it at a batch level) + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) Args: - table: Pandas dataframe - queries: List of Strings, containing questions related to the table + encoded_inputs: Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + >= 7.5 (Volta). 
+ return_attention_mask: (optional) Set to False to avoid returning attention mask (default: set to model specifics) """ - - if "return_lengths" in kwargs: - if verbose: - warnings.warn( - "The PreTrainedTokenizerBase.prepare_for_model `return_lengths` parameter is deprecated. " - "Please use `return_length` instead.", - FutureWarning, - ) - return_length = kwargs["return_lengths"] - - # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' - # padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( - # padding=padding, - # truncation=truncation, - # max_length=max_length, - # pad_to_multiple_of=pad_to_multiple_of, - # verbose=verbose, - # **kwargs, - # ) - # Load from model defaults - if return_token_type_ids is None: - return_token_type_ids = "token_type_ids" in self.model_input_names if return_attention_mask is None: return_attention_mask = "attention_mask" in self.model_input_names - encoded_inputs = {} - - if return_overflowing_tokens: - # currently, if drop_rows_to_fit is set to False and a table is too big, a ValueError is thrown - # see function _get_num_rows - raise ValueError("Overflowing tokens is currently not supported") - - if (answer_coordinates and not answer_texts) or (not answer_coordinates and answer_texts): - raise ValueError("In case you provide answers, both answer_coordinates and answer_text should be provided") - - add_loss_variables = None - if answer_coordinates is not None and answer_texts is not None: - assert len(answer_coordinates) == len(answer_texts) == len(queries) - add_loss_variables = True - - # First, tokenize the table and get the number of rows and columns - tokenized_table = self._tokenize_table(table) - num_rows = self._get_num_rows(table, self.drop_rows_to_fit) - num_columns = self._get_num_columns(table) + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(encoded_inputs["input_ids"]) - # Second, create the input ids for every table + query pair (and all the other features). 
This is a list of lists - features_examples = {} - position_to_label_ids = {} - for position, query in enumerate(queries): - if isinstance(query, str): - text_tokens = self.tokenize(query) - # currently, padding is done within the _to_trimmed_features function - serialized_example, features = self._to_trimmed_features( - question=query, - table=table, - question_tokens=text_tokens, - tokenized_table=tokenized_table, - num_columns=num_columns, - num_rows=num_rows, - drop_rows_to_fit=self.drop_rows_to_fit, - ) + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of - if add_loss_variables: - column_ids = serialized_example.column_ids - row_ids = serialized_example.row_ids + needs_to_be_padded = ( + padding_strategy != PaddingStrategy.DO_NOT_PAD and len(encoded_inputs["input_ids"]) != max_length + ) - # create label ids from answer texts and coordinates - label_ids = self.get_answer_ids( - column_ids, - row_ids, - tokenized_table, - query, - answer_texts[position], - answer_coordinates[position], + if needs_to_be_padded: + difference = max_length - len(encoded_inputs["input_ids"]) + if self.padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) + [0] * difference + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = ( + encoded_inputs["token_type_ids"] + [[self.pad_token_type_id] * 7] * difference ) - self._pad_to_seq_length(label_ids) - position_to_label_ids[position] = label_ids - features["label_ids"] = label_ids - - if position == 0: - prev_label_ids = [0] * len(features["input_ids"]) - else: - # TO DO: add prev label ids logic (see line 1118 in tf_example_utils.py) - prev_label_ids = position_to_label_ids[position - 1] - self._pad_to_seq_length(prev_label_ids) - features["prev_label_ids"] = prev_label_ids - - else: - prev_label_ids = [0] * len(features["input_ids"]) - self._pad_to_seq_length(prev_label_ids) - features["prev_label_ids"] = prev_label_ids - - features_examples[position] = features + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference + encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference + elif self.padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + [1] * len(encoded_inputs["input_ids"]) + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [[self.pad_token_type_id] * 7] * difference + encoded_inputs[ + "token_type_ids" + ] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"] else: - raise ValueError("Query is not valid. 
Should be a string.") - - # Build output dictionnary - encoded_inputs["input_ids"] = [features_examples[position]["input_ids"] for position in range(len(queries))] - encoded_inputs["attention_mask"] = [ - features_examples[position]["attention_mask"] for position in range(len(queries)) - ] - - token_types = [ - "segment_ids", - "column_ids", - "row_ids", - "prev_label_ids", - "column_ranks", - "inv_column_ranks", - "numeric_relations", - ] - token_type_ids = [] - for position in range(len(queries)): - token_type_ids_example = [] - for token_idx in range(self.model_max_length): - token_ids = [] - for type in token_types: - token_ids.append(features_examples[position][type][token_idx]) - token_type_ids_example.append(token_ids) - # token_type_ids_example is a list of seq_length elements, each element being a list of 7 elements - token_type_ids.append(token_type_ids_example) - - if return_token_type_ids: - encoded_inputs["token_type_ids"] = token_type_ids - - if add_loss_variables: - encoded_inputs["label_ids"] = [ - features_examples[position]["label_ids"] for position in range(len(queries)) - ] - encoded_inputs["numeric_values"] = [ - features_examples[position]["numeric_values"] for position in range(len(queries)) - ] - encoded_inputs["numeric_values_scale"] = [ - features_examples[position]["numeric_values_scale"] for position in range(len(queries)) - ] - # to do: add aggregation function id, classification class index and answer (or should people prepare this themselves?) - - if return_special_tokens_mask: - raise ValueError("Special tokens mask is currently not supported") - - if return_length: - encoded_inputs["length"] = len(encoded_inputs["input_ids"]) - - batch_outputs = BatchEncoding(encoded_inputs, tensor_type=return_tensors) + raise ValueError("Invalid padding strategy:" + str(self.padding_side)) + else: + if return_attention_mask: + encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"]) - return batch_outputs + return encoded_inputs #### Everything related to converting logits to predictions #### diff --git a/tests/test_tokenization_tapas.py b/tests/test_tokenization_tapas.py new file mode 100644 index 000000000000..baae26423406 --- /dev/null +++ b/tests/test_tokenization_tapas.py @@ -0,0 +1,1195 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import inspect +import os +import shutil +import tempfile +import unittest +from typing import List, Tuple + +import pandas as pd + +from transformers import AddedToken +from transformers.testing_utils import require_tokenizers, slow +from transformers.tokenization_tapas import ( + VOCAB_FILES_NAMES, + BasicTokenizer, + TapasTokenizer, + WordpieceTokenizer, + _is_control, + _is_punctuation, + _is_whitespace, +) + +from .test_tokenization_common import TokenizerTesterMixin, filter_non_english + + +@require_tokenizers +class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase): + + tokenizer_class = TapasTokenizer + test_rust_tokenizer = False + space_between_special_tokens = True + from_pretrained_filter = filter_non_english + + def get_table( + self, + tokenizer: TapasTokenizer, + length=5, + ): + toks = [tokenizer.decode([i], clean_up_tokenization_spaces=False) for i in range(len(tokenizer))] + + if length == 0: + data = {} + else: + data = {toks[0]: [toks[tok] for tok in range(1, length)]} + + table = pd.DataFrame.from_dict(data) + + return table + + def get_table_and_query( + self, + tokenizer: TapasTokenizer, + add_special_tokens: bool = True, + length=5, + ): + toks = [tokenizer.decode([i], clean_up_tokenization_spaces=False) for i in range(len(tokenizer))] + table = self.get_table(tokenizer, length=length - 3) + query = " ".join(toks[:3]) + + return table, query + + def get_clean_sequence( + self, + tokenizer: TapasTokenizer, + with_prefix_space=False, + max_length=20, + min_length=5, + empty_table: bool = False, + add_special_tokens: bool = True, + return_table_and_query: bool = False, + ): + + toks = [tokenizer.decode([i], clean_up_tokenization_spaces=False) for i in range(len(tokenizer))] + + if empty_table: + table = pd.DataFrame.from_dict({}) + query = " ".join(toks[:min_length]) + else: + data = {toks[0]: [toks[tok] for tok in range(1, min_length - 3)]} + table = pd.DataFrame.from_dict(data) + query = " ".join(toks[:3]) + + output_ids = tokenizer.encode(table, query, add_special_tokens=add_special_tokens) + output_txt = tokenizer.decode(output_ids) + + assert len(output_ids) >= min_length, "Update the code to generate the sequences so that they are larger" + assert len(output_ids) <= max_length, "Update the code to generate the sequences so that they are smaller" + + if return_table_and_query: + return output_txt, output_ids, table, query + + return output_txt, output_ids + + # def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20, min_length=5) -> Tuple[str, list]: + # data = { + # 'Actors': ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], + # 'Age': ["56", "45", "59"], + # 'Number of movies': ["87", "53", "69"], + # 'Date of birth': ["18 december 1963", "11 november 1974", "6 may 1961"] + # } + # table = pd.DataFrame.from_dict(data) + # output_ids = tokenizer.encode(table, add_special_tokens=False, max_length=max_length) + # output_txt = tokenizer.decode(output_ids) + # + # return output_txt, output_ids + + def setUp(self): + super().setUp() + + vocab_tokens = [ + "[UNK]", + "[CLS]", + "[SEP]", + "[PAD]", + "[MASK]", + "want", + "##want", + "##ed", + "wa", + "un", + "runn", + "##ing", + ",", + "low", + "lowest", + ] + self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"]) + with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer: + vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) + + def get_input_output_texts(self, tokenizer): + input_text = "UNwant\u00E9d,running" + output_text = 
"unwanted, running" + return input_text, output_text + + def test_full_tokenizer(self): + tokenizer = self.tokenizer_class(self.vocab_file) + + tokens = tokenizer.tokenize("UNwant\u00E9d,running") + self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) + self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [9, 6, 7, 12, 10, 11]) + + def test_rust_and_python_full_tokenizers(self): + if not self.test_rust_tokenizer: + return + + tokenizer = self.get_tokenizer() + rust_tokenizer = self.get_rust_tokenizer() + + sequence = "UNwant\u00E9d,running" + + tokens = tokenizer.tokenize(sequence) + rust_tokens = rust_tokenizer.tokenize(sequence) + self.assertListEqual(tokens, rust_tokens) + + ids = tokenizer.encode(sequence, add_special_tokens=False) + rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False) + self.assertListEqual(ids, rust_ids) + + rust_tokenizer = self.get_rust_tokenizer() + ids = tokenizer.encode(sequence) + rust_ids = rust_tokenizer.encode(sequence) + self.assertListEqual(ids, rust_ids) + + # With lower casing + tokenizer = self.get_tokenizer(do_lower_case=True) + rust_tokenizer = self.get_rust_tokenizer(do_lower_case=True) + + sequence = "UNwant\u00E9d,running" + + tokens = tokenizer.tokenize(sequence) + rust_tokens = rust_tokenizer.tokenize(sequence) + self.assertListEqual(tokens, rust_tokens) + + ids = tokenizer.encode(sequence, add_special_tokens=False) + rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False) + self.assertListEqual(ids, rust_ids) + + rust_tokenizer = self.get_rust_tokenizer() + ids = tokenizer.encode(sequence) + rust_ids = rust_tokenizer.encode(sequence) + self.assertListEqual(ids, rust_ids) + + def test_chinese(self): + tokenizer = BasicTokenizer() + + self.assertListEqual(tokenizer.tokenize("ah\u535A\u63A8zz"), ["ah", "\u535A", "\u63A8", "zz"]) + + def test_basic_tokenizer_lower(self): + tokenizer = BasicTokenizer(do_lower_case=True) + + self.assertListEqual( + tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["hello", "!", "how", "are", "you", "?"] + ) + self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"]) + + def test_basic_tokenizer_lower_strip_accents_false(self): + tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=False) + + self.assertListEqual( + tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hällo", "!", "how", "are", "you", "?"] + ) + self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["h\u00E9llo"]) + + def test_basic_tokenizer_lower_strip_accents_true(self): + tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=True) + + self.assertListEqual( + tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"] + ) + self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"]) + + def test_basic_tokenizer_lower_strip_accents_default(self): + tokenizer = BasicTokenizer(do_lower_case=True) + + self.assertListEqual( + tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"] + ) + self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"]) + + def test_basic_tokenizer_no_lower(self): + tokenizer = BasicTokenizer(do_lower_case=False) + + self.assertListEqual( + tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["HeLLo", "!", "how", "Are", "yoU", "?"] + ) + + def test_basic_tokenizer_no_lower_strip_accents_false(self): + tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=False) + + self.assertListEqual( + tokenizer.tokenize(" \tHäLLo!how \n Are yoU? 
"), ["HäLLo", "!", "how", "Are", "yoU", "?"] + ) + + def test_basic_tokenizer_no_lower_strip_accents_true(self): + tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=True) + + self.assertListEqual( + tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HaLLo", "!", "how", "Are", "yoU", "?"] + ) + + def test_basic_tokenizer_respects_never_split_tokens(self): + tokenizer = BasicTokenizer(do_lower_case=False, never_split=["[UNK]"]) + + self.assertListEqual( + tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"] + ) + + def test_wordpiece_tokenizer(self): + vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"] + + vocab = {} + for (i, token) in enumerate(vocab_tokens): + vocab[token] = i + tokenizer = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]") + + self.assertListEqual(tokenizer.tokenize(""), []) + + self.assertListEqual(tokenizer.tokenize("unwanted running"), ["un", "##want", "##ed", "runn", "##ing"]) + + self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) + + def test_is_whitespace(self): + self.assertTrue(_is_whitespace(" ")) + self.assertTrue(_is_whitespace("\t")) + self.assertTrue(_is_whitespace("\r")) + self.assertTrue(_is_whitespace("\n")) + self.assertTrue(_is_whitespace("\u00A0")) + + self.assertFalse(_is_whitespace("A")) + self.assertFalse(_is_whitespace("-")) + + def test_is_control(self): + self.assertTrue(_is_control("\u0005")) + + self.assertFalse(_is_control("A")) + self.assertFalse(_is_control(" ")) + self.assertFalse(_is_control("\t")) + self.assertFalse(_is_control("\r")) + + def test_is_punctuation(self): + self.assertTrue(_is_punctuation("-")) + self.assertTrue(_is_punctuation("$")) + self.assertTrue(_is_punctuation("`")) + self.assertTrue(_is_punctuation(".")) + + self.assertFalse(_is_punctuation("A")) + self.assertFalse(_is_punctuation(" ")) + + def test_clean_text(self): + tokenizer = self.get_tokenizer() + # rust_tokenizer = self.get_rust_tokenizer() + + # Example taken from the issue https://github.com/huggingface/tokenizers/issues/340 + self.assertListEqual([tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]]) + + # self.assertListEqual( + # [rust_tokenizer.tokenize(t) for t in ["Test", "\xad", "test"]], [["[UNK]"], [], ["[UNK]"]] + # ) + + @slow + def test_sequence_builders(self): + tokenizer = self.tokenizer_class.from_pretrained("tapas-base-uncased") + + text = tokenizer.encode("sequence builders", add_special_tokens=False) + text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False) + + encoded_sentence = tokenizer.build_inputs_with_special_tokens(text) + encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) + + assert encoded_sentence == [101] + text + [102] + assert encoded_pair == [101] + text + [102] + text_2 + [102] + + def test_offsets_with_special_characters(self): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest("{} ({})".format(tokenizer.__class__.__name__, pretrained_name)): + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + sentence = f"A, naïve {tokenizer_r.mask_token} AllenNLP sentence." 
+ tokens = tokenizer_r.encode_plus( + sentence, + return_attention_mask=False, + return_token_type_ids=False, + return_offsets_mapping=True, + add_special_tokens=True, + ) + + do_lower_case = tokenizer_r.do_lower_case if hasattr(tokenizer_r, "do_lower_case") else False + expected_results = ( + [ + ((0, 0), tokenizer_r.cls_token), + ((0, 1), "A"), + ((1, 2), ","), + ((3, 5), "na"), + ((5, 6), "##ï"), + ((6, 8), "##ve"), + ((9, 15), tokenizer_r.mask_token), + ((16, 21), "Allen"), + ((21, 23), "##NL"), + ((23, 24), "##P"), + ((25, 33), "sentence"), + ((33, 34), "."), + ((0, 0), tokenizer_r.sep_token), + ] + if not do_lower_case + else [ + ((0, 0), tokenizer_r.cls_token), + ((0, 1), "a"), + ((1, 2), ","), + ((3, 8), "naive"), + ((9, 15), tokenizer_r.mask_token), + ((16, 21), "allen"), + ((21, 23), "##nl"), + ((23, 24), "##p"), + ((25, 33), "sentence"), + ((33, 34), "."), + ((0, 0), tokenizer_r.sep_token), + ] + ) + + self.assertEqual( + [e[1] for e in expected_results], tokenizer_r.convert_ids_to_tokens(tokens["input_ids"]) + ) + self.assertEqual([e[0] for e in expected_results], tokens["offset_mapping"]) + + def test_tapas_integration_test(self): + data = { + "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"], + "Age": ["56", "45", "59"], + "Number of movies": ["87", "53", "69"], + "Date of birth": ["18 december 1963", "11 november 1974", "6 may 1961"], + } + queries = [ + "When was Brad Pitt born?", + "Which actor appeared in the least number of movies?", + "What is the average number of movies?", + ] + table = pd.DataFrame.from_dict(data) + + # TODO: Should update this in the future + tokenizer = TapasTokenizer.from_pretrained("lysandre/tapas-temporary-repo", model_max_length=512) + + expected_results = { + "input_ids": [ + 101, + 2043, + 2001, + 8226, + 15091, + 2141, + 1029, + 102, + 5889, + 2287, + 2193, + 1997, + 5691, + 3058, + 1997, + 4182, + 8226, + 15091, + 5179, + 6584, + 2324, + 2285, + 3699, + 14720, + 4487, + 6178, + 9488, + 3429, + 5187, + 2340, + 2281, + 3326, + 2577, + 18856, + 7828, + 3240, + 5354, + 6353, + 1020, + 2089, + 3777, + ], + "attention_mask": [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ], + "token_type_ids": [ + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0], + [1, 2, 0, 0, 0, 0, 0], + [1, 3, 0, 0, 0, 0, 0], + [1, 3, 0, 0, 0, 0, 0], + [1, 3, 0, 0, 0, 0, 0], + [1, 4, 0, 0, 0, 0, 0], + [1, 4, 0, 0, 0, 0, 0], + [1, 4, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0], + [1, 2, 1, 0, 2, 2, 0], + [1, 3, 1, 0, 3, 1, 0], + [1, 4, 1, 0, 2, 2, 0], + [1, 4, 1, 0, 2, 2, 0], + [1, 4, 1, 0, 2, 2, 0], + [1, 1, 2, 0, 0, 0, 0], + [1, 1, 2, 0, 0, 0, 0], + [1, 1, 2, 0, 0, 0, 0], + [1, 1, 2, 0, 0, 0, 0], + [1, 2, 2, 0, 1, 3, 0], + [1, 3, 2, 0, 1, 3, 0], + [1, 4, 2, 0, 3, 1, 0], + [1, 4, 2, 0, 3, 1, 0], + [1, 4, 2, 0, 3, 1, 0], + [1, 1, 3, 0, 0, 0, 0], + [1, 1, 3, 0, 0, 0, 0], + [1, 1, 3, 0, 0, 0, 0], + [1, 1, 3, 0, 0, 0, 0], + [1, 2, 3, 0, 3, 1, 0], + [1, 3, 3, 0, 2, 2, 0], + [1, 4, 3, 0, 1, 3, 0], + [1, 4, 3, 0, 1, 3, 0], + [1, 4, 3, 0, 1, 3, 0], + ], + } + + new_encoded_inputs = tokenizer.encode_plus(table=table, query=queries[0], padding="max_length") + + self.assertDictEqual(new_encoded_inputs, expected_results) + + 
def test_add_special_tokens(self): + tokenizers: List[TapasTokenizer] = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + input_table = self.get_table(tokenizer, length=0) + + special_token = "[SPECIAL_TOKEN]" + + tokenizer.add_special_tokens({"cls_token": special_token}) + encoded_special_token = tokenizer.encode(input_table, special_token, add_special_tokens=False) + self.assertEqual(len(encoded_special_token), 1) + + decoded = tokenizer.decode(encoded_special_token, skip_special_tokens=True) + self.assertTrue(special_token not in decoded) + + def test_add_tokens_tokenizer(self): + tokenizers: List[TapasTokenizer] = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + table = self.get_table(tokenizer, length=0) + vocab_size = tokenizer.vocab_size + all_size = len(tokenizer) + + self.assertNotEqual(vocab_size, 0) + + # We usually have added tokens from the start in tests because our vocab fixtures are + # smaller than the original vocabs - let's not assert this + # self.assertEqual(vocab_size, all_size) + + new_toks = ["aaaaa bbbbbb", "cccccccccdddddddd"] + added_toks = tokenizer.add_tokens(new_toks) + vocab_size_2 = tokenizer.vocab_size + all_size_2 = len(tokenizer) + + self.assertNotEqual(vocab_size_2, 0) + self.assertEqual(vocab_size, vocab_size_2) + self.assertEqual(added_toks, len(new_toks)) + self.assertEqual(all_size_2, all_size + len(new_toks)) + + tokens = tokenizer.encode(table, "aaaaa bbbbbb low cccccccccdddddddd l", add_special_tokens=False) + + self.assertGreaterEqual(len(tokens), 4) + self.assertGreater(tokens[0], tokenizer.vocab_size - 1) + self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) + + new_toks_2 = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"} + added_toks_2 = tokenizer.add_special_tokens(new_toks_2) + vocab_size_3 = tokenizer.vocab_size + all_size_3 = len(tokenizer) + + self.assertNotEqual(vocab_size_3, 0) + self.assertEqual(vocab_size, vocab_size_3) + self.assertEqual(added_toks_2, len(new_toks_2)) + self.assertEqual(all_size_3, all_size_2 + len(new_toks_2)) + + tokens = tokenizer.encode( + table, + ">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l", + add_special_tokens=False, + ) + + self.assertGreaterEqual(len(tokens), 6) + self.assertGreater(tokens[0], tokenizer.vocab_size - 1) + self.assertGreater(tokens[0], tokens[1]) + self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) + self.assertGreater(tokens[-2], tokens[-3]) + self.assertEqual(tokens[0], tokenizer.eos_token_id) + self.assertEqual(tokens[-2], tokenizer.pad_token_id) + + @require_tokenizers + def test_encode_decode_with_spaces(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + table = self.get_table(tokenizer, length=0) + + # new_toks = ["[ABC]", "[DEF]"] # TODO(thom) add this one back when Rust toks are ready: , "GHI IHG"] + new_toks = [AddedToken("[ABC]", normalized=False), AddedToken("[DEF]", normalized=False)] + tokenizer.add_tokens(new_toks) + input = "[ABC][DEF][ABC][DEF]" # TODO(thom) add back cf above: "[ABC] [DEF] [ABC] GHI IHG [DEF]" + if self.space_between_special_tokens: + output = "[ABC] [DEF] [ABC] [DEF]" + else: + output = input + encoded = tokenizer.encode(table, input, add_special_tokens=False) + decoded = tokenizer.decode(encoded, 
spaces_between_special_tokens=self.space_between_special_tokens) + self.assertIn(decoded, [output, output.lower()]) + + def test_encode_plus_with_padding(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + table = self.get_table(tokenizer, length=0) + sequence = "Sequence" + + # check correct behaviour if no pad_token_id exists and add it eventually + self._check_no_pad_token_padding(tokenizer, sequence) + + padding_size = 10 + padding_idx = tokenizer.pad_token_id + token_type_padding_idx = tokenizer.pad_token_type_id + + encoded_sequence = tokenizer.encode_plus(table, sequence, return_special_tokens_mask=True) + input_ids = encoded_sequence["input_ids"] + special_tokens_mask = encoded_sequence["special_tokens_mask"] + sequence_length = len(input_ids) + + # Test 'longest' and 'no_padding' don't do anything + tokenizer.padding_side = "right" + + not_padded_sequence = tokenizer.encode_plus( + table, + sequence, + padding=True, + return_special_tokens_mask=True, + ) + not_padded_input_ids = not_padded_sequence["input_ids"] + + not_padded_special_tokens_mask = not_padded_sequence["special_tokens_mask"] + not_padded_sequence_length = len(not_padded_input_ids) + + assert sequence_length == not_padded_sequence_length + assert input_ids == not_padded_input_ids + assert special_tokens_mask == not_padded_special_tokens_mask + + not_padded_sequence = tokenizer.encode_plus( + table, + sequence, + padding=False, + return_special_tokens_mask=True, + ) + not_padded_input_ids = not_padded_sequence["input_ids"] + + not_padded_special_tokens_mask = not_padded_sequence["special_tokens_mask"] + not_padded_sequence_length = len(not_padded_input_ids) + + assert sequence_length == not_padded_sequence_length + assert input_ids == not_padded_input_ids + assert special_tokens_mask == not_padded_special_tokens_mask + + # Test right padding + tokenizer.padding_side = "right" + + right_padded_sequence = tokenizer.encode_plus( + table, + sequence, + max_length=sequence_length + padding_size, + padding="max_length", + return_special_tokens_mask=True, + ) + right_padded_input_ids = right_padded_sequence["input_ids"] + + right_padded_special_tokens_mask = right_padded_sequence["special_tokens_mask"] + right_padded_sequence_length = len(right_padded_input_ids) + + assert sequence_length + padding_size == right_padded_sequence_length + assert input_ids + [padding_idx] * padding_size == right_padded_input_ids + assert special_tokens_mask + [1] * padding_size == right_padded_special_tokens_mask + + # Test left padding + tokenizer.padding_side = "left" + left_padded_sequence = tokenizer.encode_plus( + table, + sequence, + max_length=sequence_length + padding_size, + padding="max_length", + return_special_tokens_mask=True, + ) + left_padded_input_ids = left_padded_sequence["input_ids"] + left_padded_special_tokens_mask = left_padded_sequence["special_tokens_mask"] + left_padded_sequence_length = len(left_padded_input_ids) + + assert sequence_length + padding_size == left_padded_sequence_length + assert [padding_idx] * padding_size + input_ids == left_padded_input_ids + assert [1] * padding_size + special_tokens_mask == left_padded_special_tokens_mask + + if "token_type_ids" in tokenizer.model_input_names: + token_type_ids = encoded_sequence["token_type_ids"] + left_padded_token_type_ids = left_padded_sequence["token_type_ids"] + right_padded_token_type_ids = right_padded_sequence["token_type_ids"] + + assert ( + token_type_ids + 
[[token_type_padding_idx] * 7] * padding_size == right_padded_token_type_ids + ) + assert [[token_type_padding_idx] * 7] * padding_size + token_type_ids == left_padded_token_type_ids + + if "attention_mask" in tokenizer.model_input_names: + attention_mask = encoded_sequence["attention_mask"] + right_padded_attention_mask = right_padded_sequence["attention_mask"] + left_padded_attention_mask = left_padded_sequence["attention_mask"] + + assert attention_mask + [0] * padding_size == right_padded_attention_mask + assert [0] * padding_size + attention_mask == left_padded_attention_mask + + def test_internal_consistency(self): + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + table = self.get_table(tokenizer, length=0) + input_text, output_text = self.get_input_output_texts(tokenizer) + + tokens = tokenizer.tokenize(input_text) + ids = tokenizer.convert_tokens_to_ids(tokens) + ids_2 = tokenizer.encode(table, input_text, add_special_tokens=False) + self.assertListEqual(ids, ids_2) + + tokens_2 = tokenizer.convert_ids_to_tokens(ids) + self.assertNotEqual(len(tokens_2), 0) + text_2 = tokenizer.decode(ids) + self.assertIsInstance(text_2, str) + + self.assertEqual(text_2, output_text) + + def test_mask_output(self): + tokenizers = self.get_tokenizers(fast=False, do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + table, query = self.get_table_and_query(tokenizer) + + if ( + tokenizer.build_inputs_with_special_tokens.__qualname__.split(".")[0] != "PreTrainedTokenizer" + and "token_type_ids" in tokenizer.model_input_names + ): + information = tokenizer.encode_plus(table, query, add_special_tokens=True) + sequences, mask = information["input_ids"], information["token_type_ids"] + self.assertEqual(len(sequences), len(mask)) + + @unittest.skip("TAPAS tokenizer only handles two sequences.") + def test_maximum_encoding_length_pair_input(self): + pass + + @unittest.skip("TAPAS tokenizer only handles two sequences.") + def test_maximum_encoding_length_single_input(self): + pass + + def test_number_of_added_tokens(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + + table, query = self.get_table_and_query(tokenizer) + + sequences = tokenizer.encode(table, query, add_special_tokens=False) + attached_sequences = tokenizer.encode(table, query, add_special_tokens=True) + + # Method is implemented (e.g. 
not GPT-2) + if len(attached_sequences) != 2: + self.assertEqual( + tokenizer.num_special_tokens_to_add(pair=True), len(attached_sequences) - len(sequences) + ) + + def test_padding_to_max_length(self): + """We keep this test for backward compatibility but it should be removed when `pad_to_max_length` will be deprecated""" + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + table = self.get_table(tokenizer) + sequence = "Sequence" + padding_size = 10 + + # check correct behaviour if no pad_token_id exists and add it eventually + self._check_no_pad_token_padding(tokenizer, sequence) + + padding_idx = tokenizer.pad_token_id + + # Check that it correctly pads when a maximum length is specified along with the padding flag set to True + tokenizer.padding_side = "right" + encoded_sequence = tokenizer.encode(table, sequence) + sequence_length = len(encoded_sequence) + # FIXME: the next line should be padding(max_length) to avoid warning + padded_sequence = tokenizer.encode( + table, sequence, max_length=sequence_length + padding_size, pad_to_max_length=True + ) + padded_sequence_length = len(padded_sequence) + assert sequence_length + padding_size == padded_sequence_length + assert encoded_sequence + [padding_idx] * padding_size == padded_sequence + + # Check that nothing is done when a maximum length is not specified + encoded_sequence = tokenizer.encode(table, sequence) + sequence_length = len(encoded_sequence) + + tokenizer.padding_side = "right" + padded_sequence_right = tokenizer.encode(table, sequence, pad_to_max_length=True) + padded_sequence_right_length = len(padded_sequence_right) + assert sequence_length == padded_sequence_right_length + assert encoded_sequence == padded_sequence_right + + def test_padding_to_multiple_of(self): + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + if tokenizer.pad_token is None: + self.skipTest("No padding token.") + else: + empty_tokens = tokenizer("", padding=True, pad_to_multiple_of=8) + normal_tokens = tokenizer("This is a sample input", padding=True, pad_to_multiple_of=8) + for key, value in empty_tokens.items(): + self.assertEqual(len(value) % 8, 0, "BatchEncoding.{} is not multiple of 8".format(key)) + for key, value in normal_tokens.items(): + self.assertEqual(len(value) % 8, 0, "BatchEncoding.{} is not multiple of 8".format(key)) + + normal_tokens = tokenizer("This", pad_to_multiple_of=8) + for key, value in normal_tokens.items(): + self.assertNotEqual(len(value) % 8, 0, "BatchEncoding.{} is not multiple of 8".format(key)) + + # Should also work with truncation + normal_tokens = tokenizer("This", padding=True, truncation=True, pad_to_multiple_of=8) + for key, value in normal_tokens.items(): + self.assertEqual(len(value) % 8, 0, "BatchEncoding.{} is not multiple of 8".format(key)) + + # truncation to something which is not a multiple of pad_to_multiple_of raises an error + self.assertRaises( + ValueError, + tokenizer.__call__, + "This", + padding=True, + truncation=True, + max_length=12, + pad_to_multiple_of=8, + ) + + def test_call(self): + # Tests that all call wrap to encode_plus and batch_encode_plus + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + sequences = [ + "Testing batch encode plus", + "Testing batch encode plus with different sequence lengths", + "Testing batch encode plus 
with different sequence lengths correctly pads", + ] + + # Test not batched + table = self.get_table(tokenizer, length=0) + encoded_sequences_1 = tokenizer.encode_plus(table, sequences[0]) + encoded_sequences_2 = tokenizer(table, sequences[0]) + self.assertEqual(encoded_sequences_1, encoded_sequences_2) + + # Test not batched pairs + table = self.get_table(tokenizer, length=10) + encoded_sequences_1 = tokenizer.encode_plus(table, sequences[1]) + encoded_sequences_2 = tokenizer(table, sequences[1]) + self.assertEqual(encoded_sequences_1, encoded_sequences_2) + + # Test batched + table = self.get_table(tokenizer, length=0) + encoded_sequences_1 = tokenizer.batch_encode_plus(table, sequences) + encoded_sequences_2 = tokenizer(table, sequences) + self.assertEqual(encoded_sequences_1, encoded_sequences_2) + + def test_batch_encode_plus_batch_sequence_length(self): + # Tests that all encoded values have the correct size + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + table = self.get_table(tokenizer, length=0) + sequences = [ + "Testing batch encode plus", + "Testing batch encode plus with different sequence lengths", + "Testing batch encode plus with different sequence lengths correctly pads", + ] + + encoded_sequences = [tokenizer.encode_plus(table, sequence) for sequence in sequences] + encoded_sequences_batch = tokenizer.batch_encode_plus(table, sequences, padding=False) + self.assertListEqual( + encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch) + ) + + maximum_length = len( + max([encoded_sequence["input_ids"] for encoded_sequence in encoded_sequences], key=len) + ) + + # check correct behaviour if no pad_token_id exists and add it eventually + self._check_no_pad_token_padding(tokenizer, sequences) + + encoded_sequences_padded = [ + tokenizer.encode_plus(table, sequence, max_length=maximum_length, padding="max_length") + for sequence in sequences + ] + + encoded_sequences_batch_padded = tokenizer.batch_encode_plus(table, sequences, padding=True) + self.assertListEqual( + encoded_sequences_padded, + self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch_padded), + ) + + # check 'longest' is unsensitive to a max length + encoded_sequences_batch_padded_1 = tokenizer.batch_encode_plus(table, sequences, padding=True) + encoded_sequences_batch_padded_2 = tokenizer.batch_encode_plus( + table, sequences, max_length=maximum_length + 10, padding="longest" + ) + for key in encoded_sequences_batch_padded_1.keys(): + self.assertListEqual( + encoded_sequences_batch_padded_1[key], + encoded_sequences_batch_padded_2[key], + ) + + # check 'no_padding' is unsensitive to a max length + encoded_sequences_batch_padded_1 = tokenizer.batch_encode_plus(table, sequences, padding=False) + encoded_sequences_batch_padded_2 = tokenizer.batch_encode_plus( + table, sequences, max_length=maximum_length + 10, padding=False + ) + for key in encoded_sequences_batch_padded_1.keys(): + self.assertListEqual( + encoded_sequences_batch_padded_1[key], + encoded_sequences_batch_padded_2[key], + ) + + def test_batch_encode_plus_overflowing_tokens(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + table = self.get_table(tokenizer, length=0) + string_sequences = ["Testing the prepare_for_model method.", "Test"] + + if tokenizer.pad_token is None: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + 
tokenizer.batch_encode_plus( + table, string_sequences, return_overflowing_tokens=True, truncation=True, padding=True, max_length=3 + ) + + def test_batch_encode_plus_padding(self): + # Test that padded sequences are equivalent between batch_encode_plus and encode_plus + + # Right padding tests + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + table = self.get_table(tokenizer, length=0) + sequences = [ + "Testing batch encode plus", + "Testing batch encode plus with different sequence lengths", + "Testing batch encode plus with different sequence lengths correctly pads", + ] + + max_length = 100 + + # check correct behaviour if no pad_token_id exists and add it eventually + self._check_no_pad_token_padding(tokenizer, sequences) + + encoded_sequences = [ + tokenizer.encode_plus(table, sequence, max_length=max_length, padding="max_length") + for sequence in sequences + ] + encoded_sequences_batch = tokenizer.batch_encode_plus( + table, sequences, max_length=max_length, padding="max_length" + ) + self.assertListEqual( + encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch) + ) + + # Left padding tests + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + tokenizer.padding_side = "left" + sequences = [ + "Testing batch encode plus", + "Testing batch encode plus with different sequence lengths", + "Testing batch encode plus with different sequence lengths correctly pads", + ] + + max_length = 100 + + # check correct behaviour if no pad_token_id exists and add it eventually + self._check_no_pad_token_padding(tokenizer, sequences) + + encoded_sequences = [ + tokenizer.encode_plus(table, sequence, max_length=max_length, padding="max_length") + for sequence in sequences + ] + encoded_sequences_batch = tokenizer.batch_encode_plus( + table, sequences, max_length=max_length, padding="max_length" + ) + self.assertListEqual( + encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch) + ) + + def test_padding_to_multiple_of(self): + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + table = self.get_table(tokenizer, length=0) + if tokenizer.pad_token is None: + self.skipTest("No padding token.") + else: + empty_tokens = tokenizer(table, padding=True, pad_to_multiple_of=8) + normal_tokens = tokenizer(table, "This is a sample input", padding=True, pad_to_multiple_of=8) + for key, value in empty_tokens.items(): + self.assertEqual(len(value) % 8, 0, "BatchEncoding.{} is not multiple of 8".format(key)) + for key, value in normal_tokens.items(): + self.assertEqual(len(value) % 8, 0, "BatchEncoding.{} is not multiple of 8".format(key)) + + normal_tokens = tokenizer(table, "This", pad_to_multiple_of=8) + for key, value in normal_tokens.items(): + self.assertNotEqual(len(value) % 8, 0, "BatchEncoding.{} is not multiple of 8".format(key)) + + # Should also work with truncation + normal_tokens = tokenizer(table, "This", padding=True, truncation=True, pad_to_multiple_of=8) + for key, value in normal_tokens.items(): + self.assertEqual(len(value) % 8, 0, "BatchEncoding.{} is not multiple of 8".format(key)) + + # truncation to something which is not a multiple of pad_to_multiple_of raises an error + self.assertRaises( + ValueError, + tokenizer.__call__, + table, + "This", + 
padding=True, + truncation=True, + max_length=12, + pad_to_multiple_of=8, + ) + + @unittest.skip("TAPAS cannot handle `prepare_for_model` without passing by `encode_plus` or `batch_encode_plus`") + def test_prepare_for_model(self): + pass + + def test_tokenizer_slow_store_full_signature(self): + signature = inspect.signature(self.tokenizer_class.__init__) + tokenizer = self.get_tokenizer() + + for parameter_name, parameter in signature.parameters.items(): + if parameter.default != inspect.Parameter.empty: + self.assertIn(parameter_name, tokenizer.init_kwargs) + + def test_special_tokens_mask_input_pairs(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + sequence_0 = "Encode this." + empty_table = self.get_table(tokenizer, length=0) + table = self.get_table(tokenizer, length=10) + encoded_sequence = tokenizer.encode(empty_table, sequence_0, add_special_tokens=False) + encoded_sequence += tokenizer.encode(table, "", add_special_tokens=False) + encoded_sequence_dict = tokenizer.encode_plus( + table, + sequence_0, + add_special_tokens=True, + return_special_tokens_mask=True, + # add_prefix_space=False, + ) + encoded_sequence_w_special = encoded_sequence_dict["input_ids"] + special_tokens_mask = encoded_sequence_dict["special_tokens_mask"] + self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special)) + + filtered_sequence = [ + (x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special) + ] + filtered_sequence = [x for x in filtered_sequence if x is not None] + self.assertEqual(encoded_sequence, filtered_sequence) + + def test_special_tokens_mask(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + table = self.get_table(tokenizer, length=0) + sequence_0 = "Encode this." 
+ # Testing single inputs + encoded_sequence = tokenizer.encode(table, sequence_0, add_special_tokens=False) + encoded_sequence_dict = tokenizer.encode_plus( + table, sequence_0, add_special_tokens=True, return_special_tokens_mask=True + ) + encoded_sequence_w_special = encoded_sequence_dict["input_ids"] + special_tokens_mask = encoded_sequence_dict["special_tokens_mask"] + self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special)) + + filtered_sequence = [x for i, x in enumerate(encoded_sequence_w_special) if not special_tokens_mask[i]] + self.assertEqual(encoded_sequence, filtered_sequence) + + def test_save_and_load_tokenizer(self): + # safety check on max_len default value so we are sure the test works + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + self.assertNotEqual(tokenizer.model_max_length, 42) + + # Now let's start the test + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + # Isolate this from the other tests because we save additional tokens/etc + table = self.get_table(tokenizer, length=0) + tmpdirname = tempfile.mkdtemp() + + sample_text = " He is very happy, UNwant\u00E9d,running" + before_tokens = tokenizer.encode(table, sample_text, add_special_tokens=False) + before_vocab = tokenizer.get_vocab() + tokenizer.save_pretrained(tmpdirname) + + after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname) + after_tokens = after_tokenizer.encode(table, sample_text, add_special_tokens=False) + after_vocab = after_tokenizer.get_vocab() + self.assertListEqual(before_tokens, after_tokens) + self.assertDictEqual(before_vocab, after_vocab) + + shutil.rmtree(tmpdirname) + + def test_right_and_left_padding(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + table = self.get_table(tokenizer, length=0) + sequence = "Sequence" + padding_size = 10 + + # check correct behaviour if no pad_token_id exists and add it eventually + self._check_no_pad_token_padding(tokenizer, sequence) + + padding_idx = tokenizer.pad_token_id + + # RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True + tokenizer.padding_side = "right" + encoded_sequence = tokenizer.encode(table, sequence) + sequence_length = len(encoded_sequence) + padded_sequence = tokenizer.encode( + table, sequence, max_length=sequence_length + padding_size, padding="max_length" + ) + padded_sequence_length = len(padded_sequence) + assert sequence_length + padding_size == padded_sequence_length + assert encoded_sequence + [padding_idx] * padding_size == padded_sequence + + # LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True + tokenizer.padding_side = "left" + encoded_sequence = tokenizer.encode(table, sequence) + sequence_length = len(encoded_sequence) + padded_sequence = tokenizer.encode( + table, sequence, max_length=sequence_length + padding_size, padding="max_length" + ) + padded_sequence_length = len(padded_sequence) + assert sequence_length + padding_size == padded_sequence_length + assert [padding_idx] * padding_size + encoded_sequence == padded_sequence + + # RIGHT & LEFT PADDING - Check that nothing is done for 'longest' and 'no_padding' + encoded_sequence = tokenizer.encode(table, sequence) + sequence_length = len(encoded_sequence) 
+ + tokenizer.padding_side = "right" + padded_sequence_right = tokenizer.encode(table, sequence, padding=True) + padded_sequence_right_length = len(padded_sequence_right) + assert sequence_length == padded_sequence_right_length + assert encoded_sequence == padded_sequence_right + + tokenizer.padding_side = "left" + padded_sequence_left = tokenizer.encode(table, sequence, padding="longest") + padded_sequence_left_length = len(padded_sequence_left) + assert sequence_length == padded_sequence_left_length + assert encoded_sequence == padded_sequence_left + + tokenizer.padding_side = "right" + padded_sequence_right = tokenizer.encode(table, sequence) + padded_sequence_right_length = len(padded_sequence_right) + assert sequence_length == padded_sequence_right_length + assert encoded_sequence == padded_sequence_right + + tokenizer.padding_side = "left" + padded_sequence_left = tokenizer.encode(table, sequence, padding=False) + padded_sequence_left_length = len(padded_sequence_left) + assert sequence_length == padded_sequence_left_length + assert encoded_sequence == padded_sequence_left + + @unittest.skip("TAPAS doesn't handle pre-tokenized inputs.") + def test_pretokenized_inputs(self): + pass From eb0c17a3f838953eadcc60886debb58b08cb3b91 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Thu, 12 Nov 2020 17:07:27 -0500 Subject: [PATCH 2/2] Fix some of NielsRogge's comments --- src/transformers/tokenization_tapas.py | 48 +++++++++++--------------- 1 file changed, 21 insertions(+), 27 deletions(-) diff --git a/src/transformers/tokenization_tapas.py b/src/transformers/tokenization_tapas.py index 2484f9b38b48..1a55980e71e3 100644 --- a/src/transformers/tokenization_tapas.py +++ b/src/transformers/tokenization_tapas.py @@ -327,8 +327,7 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = return (vocab_file,) def create_attention_mask_from_sequences(self, query_ids: List[int], table_values: List[TableValue]) -> List[int]: - table_ids = list(zip(*table_values))[0] if table_values else [] - return [1] * (1 + len(query_ids) + 1) + [0] * len(table_ids) + return [1] * (1 + len(query_ids) + 1 + len(table_values)) def create_segment_token_type_ids_from_sequences( self, query_ids: List[int], table_values: List[TableValue] @@ -348,12 +347,6 @@ def create_row_token_type_ids_from_sequences( table_row_ids = list(zip(*table_values))[2] if table_values else [] return [0] * (1 + len(query_ids) + 1) + list(table_row_ids) - def create_label_ids_from_sequences_and_answers( - self, query_ids: List[int], table_values: List[TableValue] - ) -> List[int]: - table_row_ids = list(zip(*table_values))[2] if table_values else [] - return [0] * (1 + len(query_ids) + 1) + list(table_row_ids) - def build_inputs_with_special_tokens( self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None ) -> List[int]: @@ -361,8 +354,6 @@ def build_inputs_with_special_tokens( Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and adding special tokens. - This implementation does not add special tokens and this method should be overridden in a subclass. - Args: token_ids_0 (:obj:`List[int]`): The first tokenized sequence. token_ids_1 (:obj:`List[int]`, `optional`): The second tokenized sequence. 
@@ -411,6 +402,9 @@ def __call__( table: pd.DataFrame, queries: Optional[ Union[ + TextInput, + PreTokenizedInput, + EncodedInput, List[TextInput], List[PreTokenizedInput], List[EncodedInput], @@ -502,7 +496,7 @@ def __call__( return self.encode_plus( table=table, query=queries, - answer_coordinate=answer_coordinates, + answer_coordinates=answer_coordinates, answer_text=answer_texts, add_special_tokens=add_special_tokens, padding=padding, @@ -569,7 +563,7 @@ def batch_encode_plus( ) if (answer_coordinates and not answer_texts) or (not answer_coordinates and answer_texts): - raise ValueError("In case you provide answers, both answer_coordinate and answer_text should be provided") + raise ValueError("In case you provide answers, both answer_coordinates and answer_text should be provided") elif answer_coordinates is None and answer_texts is None: answer_coordinates = answer_texts = [None] * len(queries) @@ -695,9 +689,9 @@ def _batch_prepare_for_model( queries_ids: List[List[int]], raw_table: pd.DataFrame, raw_queries: Union[ - TextInput, - PreTokenizedInput, - EncodedInput, + List[TextInput], + List[PreTokenizedInput], + List[EncodedInput], ], answer_coordinates: Optional[List[Tuple]] = None, answer_texts: Optional[List[TextInput]] = None, @@ -741,7 +735,7 @@ def _batch_prepare_for_model( table_data = None queries_tokens = [None] * len(queries_ids) - for query_ids, raw_query, query_tokens, answer_coordinate, answer_text in zip( + for query_ids, raw_query, query_tokens, answer_coords, answer_text in zip( queries_ids, raw_queries, queries_tokens, answer_coordinates, answer_texts ): outputs = self.prepare_for_model( @@ -751,7 +745,7 @@ def _batch_prepare_for_model( raw_query, table_data=table_data, query_tokens=query_tokens, - answer_coordinate=answer_coordinate, + answer_coordinates=answer_coords, answer_text=answer_text, add_special_tokens=add_special_tokens, padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward @@ -828,7 +822,7 @@ def encode_plus( EncodedInput, ] ] = None, - answer_coordinate: Optional[List[Tuple]] = None, + answer_coordinates: Optional[List[Tuple]] = None, answer_text: Optional[List[TextInput]] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, @@ -863,8 +857,8 @@ def encode_plus( "set return_token_type_ids to None." 
) - if (answer_coordinate and not answer_text) or (not answer_coordinate and answer_text): - raise ValueError("In case you provide answers, both answer_coordinate and answer_text should be provided") + if (answer_coordinates and not answer_text) or (not answer_coordinates and answer_text): + raise ValueError("In case you provide answers, both answer_coordinates and answer_text should be provided") if "is_split_into_words" in kwargs: raise NotImplementedError("Currently TapasTokenizer only supports questions as strings.") @@ -914,7 +908,7 @@ def _encode_plus( PreTokenizedInput, EncodedInput, ], - answer_coordinate: Optional[List[Tuple]] = None, + answer_coordinates: Optional[List[Tuple]] = None, answer_text: Optional[List[TextInput]] = None, add_special_tokens: bool = True, padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, @@ -961,7 +955,7 @@ def _encode_plus( query, table_data=table_data, query_tokens=query_tokens, - answer_coordinate=answer_coordinate, + answer_coordinates=answer_coordinates, answer_text=answer_text, add_special_tokens=add_special_tokens, padding=padding_strategy.value, @@ -989,7 +983,7 @@ def prepare_for_model( PreTokenizedInput, EncodedInput, ], - answer_coordinate: Optional[List[Tuple]] = None, + answer_coordinates: Optional[List[Tuple]] = None, answer_text: Optional[List[TextInput]] = None, add_special_tokens: bool = True, padding: Union[bool, str, PaddingStrategy] = False, @@ -1089,9 +1083,9 @@ def prepare_for_model( attention_mask = self.create_attention_mask_from_sequences(query_ids, table_data) encoded_inputs["attention_mask"] = attention_mask - if answer_coordinate is not None and answer_text is not None: + if answer_coordinates is not None and answer_text is not None: label_ids = self.get_answer_ids( - column_ids, row_ids, table_data, query_tokens, answer_text, answer_coordinate + column_ids, row_ids, table_data, query_tokens, answer_text, answer_coordinates ) numeric_values = self._get_numeric_values(raw_table, column_ids, row_ids, columns_to_numeric_values) numeric_values_scale = self._get_numeric_values_scale(raw_table, column_ids, row_ids) @@ -1355,7 +1349,7 @@ def _get_cell_token_indexes(self, column_ids, row_ids, column_id, row_id): yield index def _get_numeric_column_ranks(self, column_ids, row_ids, table): - """Adds column ranks for all numeric columns.""" + """Returns column ranks for all numeric columns.""" ranks = [0] * len(column_ids) inv_ranks = [0] * len(column_ids) @@ -1416,7 +1410,7 @@ def _get_numeric_sort_key_fn(self, table_numeric_values, value): def _get_numeric_relations(self, question, column_ids, row_ids, table, columns_to_numeric_values): """ - Return numeric relations embeddings + Returns numeric relations embeddings Args: question: The question, numeric values are used.
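The column-rank features touched by this last hunk can be illustrated with a standalone sketch. This is only an analogue of the idea, not TapasTokenizer's `_get_numeric_column_ranks` (which operates on the tokenized cells and on parsed numeric values such as the dates above): each numeric column receives 1-based dense ranks in ascending and descending order, and those values populate the column_ranks and inv_column_ranks token types.

import pandas as pd

def toy_column_ranks(column: pd.Series):
    """Return (rank, inverse_rank) per cell, 1-based; non-numeric cells get 0."""
    values = pd.to_numeric(column, errors="coerce")
    rank = values.rank(method="dense").fillna(0).astype(int)
    inv_rank = values.rank(method="dense", ascending=False).fillna(0).astype(int)
    return rank.tolist(), inv_rank.tolist()

# "Age" column from the integration test table: 56, 45, 59.
print(toy_column_ranks(pd.Series(["56", "45", "59"])))
# -> ([2, 1, 3], [2, 3, 1]), matching the column_rank / inv_column_rank values
#    seen in the expected token_type_ids for the Age column.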