Improve implementation
Niels Rogge authored and Niels Rogge committed Feb 21, 2022
1 parent 387e52d commit fd1db5e
Showing 5 changed files with 54 additions and 30 deletions.
10 changes: 5 additions & 5 deletions src/transformers/models/markuplm/processing_markuplm.py
@@ -30,17 +30,17 @@ class MarkupLMProcessor(ProcessorMixin):
     [`MarkupLMProcessor`] offers all the functionalities you need to prepare data for the model.
 
     It first uses [`MarkupLMFeatureExtractor`] to get nodes and corresponding xpaths from one or more HTML strings. Next,
-    these are provided to [`MarkupLMTokenizer`], which turns them into token-level `input_ids`, `attention_mask`,
+    these are provided to [`MarkupLMTokenizer`] or [`MarkupLMTokenizerFast`], which turns them into token-level `input_ids`, `attention_mask`,
     `token_type_ids`, `xpath_tags_seq` and `xpath_subs_seq`.
 
     Args:
         feature_extractor (`MarkupLMFeatureExtractor`):
             An instance of [`MarkupLMFeatureExtractor`]. The feature extractor is a required input.
-        tokenizer (`MarkupLMTokenizer`):
-            An instance of [`MarkupLMTokenizer`]. The tokenizer is a required input.
+        tokenizer (`MarkupLMTokenizer` or `MarkupLMTokenizerFast`):
+            An instance of [`MarkupLMTokenizer`] or [`MarkupLMTokenizerFast`]. The tokenizer is a required input.
     """
 
     feature_extractor_class = "MarkupLMFeatureExtractor"
-    tokenizer_class = ("MarkupLMTokenizer")
+    tokenizer_class = ("MarkupLMTokenizer", "MarkupLMTokenizerFast")
 
     def __call__(
         self,
@@ -95,4 +95,4 @@ def __call__(
             **kwargs,
         )
 
-        return encoded_inputs
\ No newline at end of file
+        return encoded_inputs
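
For context, a minimal usage sketch of the pipeline the updated docstring describes: the feature extractor pulls nodes and xpaths out of raw HTML, and either tokenizer variant then encodes them. The checkpoint name mirrors the test script below; treating it as loadable for both components is an assumption of this sketch, not something this diff establishes.

from transformers import MarkupLMFeatureExtractor, MarkupLMProcessor, MarkupLMTokenizerFast

# assemble the processor from its two required components
feature_extractor = MarkupLMFeatureExtractor()
tokenizer = MarkupLMTokenizerFast.from_pretrained("microsoft/markuplm-base")
processor = MarkupLMProcessor(feature_extractor, tokenizer)

html = "<html><body><div><span>hello world</span></div></body></html>"
encoding = processor(html, return_tensors="pt")

# per the docstring: input_ids, attention_mask, token_type_ids,
# xpath_tags_seq and xpath_subs_seq
print(encoding.keys())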
36 changes: 21 additions & 15 deletions src/transformers/models/markuplm/test_tokenizer.py
@@ -1,25 +1,31 @@
-from transformers import MarkupLMTokenizer, MarkupLMTokenizerFast
 import torch
 
-slow_tokenizer = MarkupLMTokenizer.from_pretrained("microsoft/markuplm-base")
+from transformers import MarkupLMTokenizer, MarkupLMTokenizerFast, LayoutLMv2TokenizerFast
 
-slow_encoding = slow_tokenizer(
-    ["hello", "world"],
-    xpaths=["/html/body/div/li[1]/div/span", "/html/body/div/li[1]/div/span"],
-    padding="max_length",
-    max_length=20,
-    return_tensors="pt",
-)
+# slow_tokenizer = MarkupLMTokenizer.from_pretrained("microsoft/markuplm-base")
+
+# slow_encoding = slow_tokenizer(
+#     ["hello", "world"],
+#     xpaths=["/html/body/div/li[1]/div/span", "/html/body/div/li[1]/div/span"],
+#     padding="max_length",
+#     max_length=20,
+#     return_tensors="pt",
+# )
 
 fast_tokenizer = MarkupLMTokenizerFast.from_pretrained("microsoft/markuplm-base")
 
-fast_encoding = fast_tokenizer(
-    ["hello", "world"],
-    xpaths=["/html/body/div/li[1]/div/span", "/html/body/div/li[1]/div/span"],
+fast_tokenizer_bis = LayoutLMv2TokenizerFast.from_pretrained("microsoft/layoutlmv2-base-uncased")
+
+fast_encoding = fast_tokenizer_bis(
+    ["hello", "world", "how", "are", "you"],
+    #xpaths=["/html/body/div/li[1]/div/span", "/html/body/div/li[1]/div/span", "html/body", "html/body/div"],
+    nodes=[[1,2,3,4] for _ in range(5)],
     padding="max_length",
-    max_length=20,
+    max_length=2,
     return_tensors="pt",
+    return_overflowing_tokens=True,
 )
 
-for k in slow_encoding.keys():
-    assert torch.allclose(slow_encoding[k], fast_encoding[k])
+# for k in slow_encoding.keys():
+#     assert torch.allclose(slow_encoding[k], fast_encoding[k])
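
The rewritten script exercises return_overflowing_tokens=True with a deliberately tiny max_length, using LayoutLMv2TokenizerFast as a stand-in while the MarkupLM fast tokenizer is debugged. For reference, a minimal sketch of how fast tokenizers in general surface overflow, with bert-base-uncased as a neutral stand-in (an assumption of this sketch, not part of the commit):

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
encoding = tokenizer(
    ["hello world how are you"],
    max_length=4,
    truncation=True,
    padding="max_length",
    return_overflowing_tokens=True,
    return_tensors="pt",
)

# each chunk that overflows max_length becomes its own row, and
# overflow_to_sample_mapping maps every row back to its source sample
print(encoding["input_ids"].shape)
print(encoding["overflow_to_sample_mapping"])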
10 changes: 9 additions & 1 deletion src/transformers/models/markuplm/tokenization_markuplm_fast.py
@@ -582,7 +582,9 @@ def _batch_encode_plus(

         if not isinstance(batch_text_or_text_pairs, list):
             raise TypeError(f"batch_text_or_text_pairs has to be a list (got {type(batch_text_or_text_pairs)})")
-
+
+        print("Max length:", max_length)
+
         # Set the truncation and padding strategy and restore the initial configuration
         self.set_truncation_and_padding(
             padding_strategy=padding_strategy,
@@ -595,12 +597,18 @@ def _batch_encode_plus(
         if is_pair:
             batch_text_or_text_pairs = [(text.split(), text_pair) for text, text_pair in batch_text_or_text_pairs]
 
+        print("Batch text or text pairs:", batch_text_or_text_pairs)
+
         encodings = self._tokenizer.encode_batch(
             batch_text_or_text_pairs,
             add_special_tokens=add_special_tokens,
             is_pretokenized=True,  # we set this to True as MarkupLM always expects pretokenized inputs
         )
 
+        print("Encodings:", encodings)
+        print(len(encodings))
+        print(encodings[0].overflowing)
+
         # Convert encoding to dict
         # `Tokens` has type: Tuple[
         #     List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]],
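
The new prints inspect the `tokenizers.Encoding` objects that `encode_batch` returns; `encodings[0].overflowing` holds the chunks cut off by truncation. A minimal sketch of that backend layer, assuming a generic fast tokenizer as a stand-in (`_tokenizer` is the same private backend handle the method above uses):

from transformers import BertTokenizerFast

hf_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
backend = hf_tokenizer._tokenizer  # the underlying `tokenizers.Tokenizer`
backend.enable_truncation(max_length=4)

encodings = backend.encode_batch(
    [["hello", "world", "how", "are", "you"]],
    is_pretokenized=True,  # mirrors the pretokenized input MarkupLM expects
)

# the extra chunks beyond max_length live on `.overflowing`, which is
# exactly what the added print(encodings[0].overflowing) shows
print(encodings[0].ids)
print(encodings[0].overflowing)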
15 changes: 9 additions & 6 deletions tests/test_processor_markuplm.py
@@ -19,16 +19,19 @@
 import unittest
 from typing import List
 
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
+from transformers import (
+    MarkupLMFeatureExtractor,
+    MarkupLMProcessor,
+    MarkupLMTokenizer,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerBase,
+    PreTrainedTokenizerFast,
+)
 from transformers.file_utils import FEATURE_EXTRACTOR_NAME, cached_property
-from transformers import MarkupLMTokenizer
 from transformers.models.markuplm.tokenization_markuplm import VOCAB_FILES_NAMES
 from transformers.testing_utils import require_torch, slow
 
 
-from transformers import MarkupLMFeatureExtractor, MarkupLMProcessor
-
-
 @require_tokenizers
 class MarkupLMProcessorTest(unittest.TestCase):
     tokenizer_class = MarkupLMTokenizer
@@ -402,4 +405,4 @@ def test_processor_case_5(self):

         # verify bbox
         expected_bbox = [[6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3], [1, 1, 2, 3], [1000, 1000, 1000, 1000]]
-        self.assertListEqual(input_processor.bbox[1].tolist()[-5:], expected_bbox)
\ No newline at end of file
+        self.assertListEqual(input_processor.bbox[1].tolist()[-5:], expected_bbox)
13 changes: 10 additions & 3 deletions tests/test_tokenization_markuplm.py
@@ -1392,7 +1392,7 @@ def test_training_new_tokenizer(self):

         # Test we can use the new tokenizer with something not seen during training
         text = [["this", "is", "the"], ["how", "are", "you"]]
-        xpaths = [["html/body"]*3, ["html/body"]*3]
+        xpaths = [["html/body"] * 3, ["html/body"] * 3]
         inputs = new_tokenizer(text, xpaths=xpaths)
         self.assertEqual(len(inputs["input_ids"]), 2)
         decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
@@ -1510,7 +1510,7 @@ def test_training_new_tokenizer_with_special_tokens_change(self):

         # Test we can use the new tokenizer with something not seen during training
         nodes = [["this", "is"], ["hello", "🤗"]]
-        xpaths = [["html/body"]*2, ["html/body"]*2]
+        xpaths = [["html/body"] * 2, ["html/body"] * 2]
         inputs = new_tokenizer(nodes, xpaths=xpaths)
         self.assertEqual(len(inputs["input_ids"]), 2)
         decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
@@ -1595,16 +1595,23 @@ def test_batch_encode_dynamic_overflowing(self):

             # Single example
             nodes, xpaths = self.get_nodes_and_xpaths()
+            print("Nodes:", nodes)
+            print("Xpaths:", xpaths)
             tokens = tokenizer.encode_plus(
                 nodes,
                 xpaths=xpaths,
-                max_length=6,
+                max_length=1,
                 padding=True,
                 truncation=True,
                 return_tensors=returned_tensor,
                 return_overflowing_tokens=True,
             )
 
+            for k, v in tokens.items():
+                print(k, v.shape)
+
+            print(tokenizer.decode(tokens.input_ids.squeeze().tolist()))
+
             for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()):
                 self.assertEqual(len(tokens[key].shape), 2)
