diff --git a/tests/nightly/gpu/test_bert.py b/tests/nightly/gpu/test_bert.py
index e4a4de3061c..33e5c635fb2 100644
--- a/tests/nightly/gpu/test_bert.py
+++ b/tests/nightly/gpu/test_bert.py
@@ -17,47 +17,36 @@ class TestBertModel(unittest.TestCase):
     samples on convai2
     """
 
-    @testing_utils.retry(ntries=3, log_retry=True)
     def test_biencoder(self):
         valid, test = testing_utils.train_model(
             dict(
-                task='convai2',
+                task='integration_tests:overfit',
                 model='bert_ranker/bi_encoder_ranker',
-                num_epochs=0.1,
-                batchsize=8,
-                learningrate=3e-4,
-                text_truncate=32,
-                validation_max_exs=20,
-                short_final_eval=True,
+                max_train_steps=500,
+                batchsize=2,
+                candidates="inline",
+                gradient_clip=1.0,
+                learningrate=1e-3,
+                text_truncate=8,
             )
         )
-        # can't conclude much from the biencoder after that little iterations.
-        # this test will just make sure it hasn't crashed and the accuracy isn't
-        # too high
-        self.assertLessEqual(test['accuracy'], 0.5)
+        self.assertGreaterEqual(test['accuracy'], 0.5)
 
-    @testing_utils.retry(ntries=3, log_retry=True)
     def test_crossencoder(self):
         valid, test = testing_utils.train_model(
             dict(
-                task='convai2',
+                task='integration_tests:overfit',
                 model='bert_ranker/cross_encoder_ranker',
-                num_epochs=0.002,
-                batchsize=1,
-                candidates="inline",
+                max_train_steps=500,
+                batchsize=2,
+                learningrate=1e-3,
+                gradient_clip=1.0,
                 type_optimization="all_encoder_layers",
-                warmup_updates=100,
-                text_truncate=32,
-                label_truncate=32,
-                validation_max_exs=20,
-                short_final_eval=True,
+                text_truncate=8,
+                label_truncate=8,
             )
         )
-        # The cross encoder reaches an interesting state MUCH faster
-        # accuracy should be present and somewhere between 0.2 and 0.8
-        # (large interval so that it doesn't flake.)
-        self.assertGreaterEqual(test['accuracy'], 0.03)
-        self.assertLessEqual(test['accuracy'], 0.8)
+        self.assertGreaterEqual(test['accuracy'], 0.8)
 
 
 if __name__ == '__main__':