[BigBird] fix bigbird slow tests #11109

Merged · 1 commit · Apr 7, 2021
42 changes: 21 additions & 21 deletions tests/test_modeling_big_bird.py
@@ -788,7 +788,7 @@ def test_tokenizer_inference(self):
model.to(torch_device)

text = [
'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to <unk>, such as saoneuhaoesuth ... This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to <unk>, such as saoneuhaoesuth ,, I was born in 92000, and this is falsé.'
"Transformer-based models are unable to process long sequences due to their self-attention operation, which scales quadratically with the sequence length. To address this limitation, we introduce the Longformer with an attention mechanism that scales linearly with sequence length, making it easy to process documents of thousands of tokens or longer. Longformer’s attention mechanism is a drop-in replacement for the standard self-attention and combines a local windowed attention with a task motivated global attention. Following prior work on long-sequence transformers, we evaluate Longformer on character-level language modeling and achieve state-of-the-art results on text8 and enwik8. In contrast to most prior work, we also pretrain Longformer and finetune it on a variety of downstream tasks. Our pretrained Longformer consistently outperforms RoBERTa on long document tasks and sets new state-of-the-art results on WikiHop and TriviaQA."
]
inputs = tokenizer(text)

@@ -798,22 +798,22 @@ def test_tokenizer_inference(self):
prediction = model(**inputs)
prediction = prediction[0]

self.assertEqual(prediction.shape, torch.Size((1, 128, 768)))
self.assertEqual(prediction.shape, torch.Size((1, 199, 768)))

expected_prediction = torch.tensor(
[
[-0.0745, 0.0689, -0.1126, -0.0610],
[-0.0343, 0.0111, -0.0269, -0.0858],
[0.1150, 0.0896, 0.0492, 0.0149],
[-0.0657, 0.2035, 0.0444, -0.0535],
[0.1143, 0.0465, 0.1583, -0.1855],
[-0.0216, 0.0807, 0.0536, 0.1371],
[-0.1879, 0.0097, -0.1916, 0.1701],
[0.7616, 0.1240, 0.0669, 0.2588],
[0.1096, -0.1810, -0.1987, 0.0445],
[0.1810, -0.3608, -0.0081, 0.1764],
[-0.0472, 0.0460, 0.0976, -0.0021],
[-0.0274, -0.3274, -0.0788, 0.0465],
[-0.0213, -0.2213, -0.0061, 0.0687],
[0.0977, 0.1858, 0.2374, 0.0483],
[0.2112, -0.2524, 0.5793, 0.0967],
[0.2473, -0.5070, -0.0630, 0.2174],
[0.2885, 0.1139, 0.6071, 0.2991],
[0.2328, -0.2373, 0.3648, 0.1058],
[0.2517, -0.0689, 0.0555, 0.0880],
[0.1021, -0.1495, -0.0635, 0.1891],
[0.0591, -0.0722, 0.2243, 0.2432],
[-0.2059, -0.2679, 0.3225, 0.6183],
[0.2280, -0.2618, 0.1693, 0.0103],
[0.0183, -0.1375, 0.2284, -0.1707],
],
device=torch_device,
)
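For reference, the hunk above pins the model's hidden states on the Longformer abstract to a hard-coded 12x4 slice. A minimal sketch of how such reference values can be regenerated is below; the checkpoint name, the tensor conversion, and the exact slice indices are assumptions, since the test's setup lines sit outside the visible hunk:

```python
import torch
from transformers import BigBirdModel, BigBirdTokenizer

# Assumed checkpoint; the test's actual from_pretrained call (and any
# attention_type / block_size overrides) is in the elided lines above.
tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
model = BigBirdModel.from_pretrained("google/bigbird-roberta-base")
model.eval()

text = [
    "Transformer-based models are unable to process long sequences ..."  # shortened; the test uses the full abstract quoted above
]
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    prediction = model(**inputs)[0]  # last hidden state

print(prediction.shape)       # with the full abstract, the updated test expects torch.Size((1, 199, 768))
print(prediction[0, :4, :4])  # a slice like this is what the expected_prediction tensor freezes
```

The exact numbers depend on the attention configuration the test uses, so a slice printed this way only matches the hard-coded tensor if the elided setup is reproduced as well.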
@@ -826,19 +826,19 @@ def test_inference_question_answering(self):
)
model.to(torch_device)

context = "🤗 Transformers (formerly known as pytorch-transformers and pytorch-pretrained-bert) provides general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet…) for Natural Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between TensorFlow 2.0 and PyTorch. Extractive Question Answering is the task of extracting an answer from a text given a question. An example of a question answering dataset is the SQuAD dataset"
context = "The BigBird model was proposed in Big Bird: Transformers for Longer Sequences by Zaheer, Manzil and Guruganesh, Guru and Dubey, Kumar Avinava and Ainslie, Joshua and Alberti, Chris and Ontanon, Santiago and Pham, Philip and Ravula, Anirudh and Wang, Qifan and Yang, Li and others. BigBird, is a sparse-attention based transformer which extends Transformer based models, such as BERT to much longer sequences. In addition to sparse attention, BigBird also applies global attention as well as random attention to the input sequence. Theoretically, it has been shown that applying sparse, global, and random attention approximates full attention, while being computationally much more efficient for longer sequences. As a consequence of the capability to handle longer context, BigBird has shown improved performance on various long document NLP tasks, such as question answering and summarization, compared to BERT or RoBERTa."

question = [
"How many pretrained models are available in 🤗 Transformers?",
"🤗 Transformers provides interoperability between which frameworks?",
"Which is better for longer sequences- BigBird or BERT?",
"What is the benefit of using BigBird over BERT?",
]
inputs = tokenizer(
question,
[context, context],
padding=True,
return_tensors="pt",
add_special_tokens=True,
max_length=128,
max_length=256,
truncation=True,
)

@@ -848,11 +848,11 @@ def test_inference_question_answering(self):

# fmt: off
target_start_logits = torch.tensor(
[[-9.5889, -10.2121, -14.2158, -11.1457, -10.7376, -7.3907, -10.2084, -9.5659, -15.0336, -8.6686, -9.1737, -11.1457, -13.4722, -6.3336, -9.6311, -8.4821, -15.141, -9.1226, -10.3328, -11.1457, -6.6793, -3.9627, 2.7126, -5.5607, -8.4625, -12.499, -11.4757, -9.6334, -4.0565, -10.0474, -7.4126, -13.5669], [-15.3796, -12.6863, -10.3951, -7.6706, -10.1808, -11.4401, -15.5868, -12.7959, -11.0186, -12.6863, -14.2198, -8.1182, -11.1353, -11.6512, -15.702, -12.8964, -12.5173, -12.6863, -14.4133, -13.1532, -12.2846, -14.1572, -11.2747, -11.1159, -11.5219, -13.1115, -11.8779, -13.989, -11.5234, -15.0459, -10.0178, -12.9253]], # noqa: E231
[[-8.9304, -10.3849, -14.4997, -9.6497, -13.9469, -7.8134, -8.9687, -13.3585, -9.7987, -13.8869, -9.2632, -8.9294, -13.6721, -7.3198, -9.5434, -11.2641, -14.3245, -9.5705, -12.7367, -8.6168, -11.083, -13.7573, -8.1151, -14.5329, -7.6876, -15.706, -12.8558, -9.1135, 8.0909, -3.1925, -11.5812, -9.4822], [-11.5595, -14.5591, -10.2978, -14.8445, -10.2092, -11.1899, -13.8356, -10.5644, -14.7706, -9.9841, -11.0052, -14.1862, -8.8173, -11.1098, -12.4686, -15.0531, -11.0196, -13.6614, -10.0236, -11.8151, -14.8744, -9.5123, -15.1605, -8.6472, -15.4184, -8.898, -9.6328, -7.0258, -11.3365, -14.4065, -10.2587, -8.9103]], # noqa: E231
device=torch_device,
)
target_end_logits = torch.tensor(
[[-12.4895, -10.9826, -13.8226, -11.9922, -13.2647, -12.4584, -10.6143, -9.4091, -16.844, -14.0393, -9.5914, -11.9922, -15.5142, -11.4073, -10.1064, -8.3961, -16.4374, -13.9323, -10.791, -11.9922, -8.736, -9.5672, 0.2844, -4.0976, -13.849, -11.8035, -12.7784, -14.1314, -7.4138, -10.5488, -8.0133, -14.8779], [-14.9831, -13.4818, -13.1566, -12.7259, -10.5892, -10.8605, -17.2376, -15.9398, -12.8739, -13.4818, -16.6979, -13.3403, -11.6416, -11.392, -16.9553, -15.723, -13.2643, -13.4818, -16.2067, -15.6688, -15.0449, -15.1253, -15.1373, -12.385, -13.3652, -15.9473, -14.9587, -15.5024, -13.1482, -16.6358, -12.3908, -15.7493]], # noqa: E231
[[-12.4131, -8.5959, -15.7163, -11.1524, -15.9913, -12.2038, -7.8902, -16.0296, -12.164, -16.5017, -13.3332, -6.9488, -15.7756, -13.8506, -11.0779, -9.2893, -15.0426, -10.1963, -17.3292, -12.2945, -11.5337, -16.4514, -9.1564, -17.5001, -9.1562, -16.2971, -13.3199, -7.5724, -5.1175, 7.2168, -10.3804, -11.9873], [-10.8654, -14.9967, -11.4144, -16.9189, -14.2673, -9.7068, -15.0182, -12.8846, -16.8716, -13.665, -10.3113, -15.1436, -14.9069, -13.3364, -11.2339, -16.0118, -11.8331, -17.0613, -13.8852, -12.4163, -16.8978, -10.7772, -17.2324, -10.6979, -16.9811, -10.3427, -9.497, -13.7104, -11.1107, -13.2936, -13.855, -14.1264]], # noqa: E231
device=torch_device,
)
# fmt: on
@@ -867,7 +867,7 @@ def test_inference_question_answering(self):
]
answer = tokenizer.batch_decode(answer)

self.assertTrue(answer == ["32", "[SEP]"])
self.assertTrue(answer == ["BigBird", "global attention"])

def test_fill_mask(self):
tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
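Taken together, the question-answering hunks follow the usual extractive-QA flow: encode question/context pairs, take the argmax of the start and end logits, and decode the token span in between, which is what produces the new expected answers ["BigBird", "global attention"]. A minimal standalone sketch of that flow, with the checkpoint name as an assumption (the test's actual model setup sits in the elided lines):

```python
import torch
from transformers import BigBirdForQuestionAnswering, BigBirdTokenizer

# Assumed checkpoint; any BigBird question-answering checkpoint follows the same API.
model_id = "google/bigbird-base-trivia-itc"
tokenizer = BigBirdTokenizer.from_pretrained(model_id)
model = BigBirdForQuestionAnswering.from_pretrained(model_id)
model.eval()

context = "BigBird, is a sparse-attention based transformer which extends Transformer based models, such as BERT to much longer sequences."  # shortened from the test's context
question = "Which is better for longer sequences- BigBird or BERT?"

inputs = tokenizer(question, context, return_tensors="pt", truncation=True, max_length=256)

with torch.no_grad():
    outputs = model(**inputs)

# Most likely start/end positions; the test slices input_ids the same way
# before batch-decoding and comparing against the expected answers.
start = torch.argmax(outputs.start_logits, dim=-1).item()
end = torch.argmax(outputs.end_logits, dim=-1).item()
answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])
print(answer)
```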