@@ -29,6 +29,7 @@
 import torch
 
 from transformers import (
+    AutoTokenizer,
     GPT2TokenizerFast,
     GPTBigCodeForCausalLM,
     GPTBigCodeForSequenceClassification,
@@ -510,7 +511,7 @@ def test_generate_simple(self):
         output_sequence = model.generate(input_ids)
         output_sentence = tokenizer.decode(output_sequence[0], skip_special_tokens=True)
 
-        expected_output = """def print_hello_world():\n    print("Hello World!")\n\n\ndef print_hello_"""
+        expected_output = 'def print_hello_world():\n    print("Hello World!")\n\n\ndef print_hello_world_with_args(name'  # fmt: skip
         self.assertEqual(output_sentence, expected_output)
 
     def test_generate_batched(self):
@@ -527,11 +528,27 @@ def test_generate_batched(self):
         outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
         expected_output = [
-            'def print_hello_world():\n    print("Hello World!")\n\n\ndef print_hello_',
-            'def say_hello():\n    print("Hello, World!")\n\n\nsay_hello()',
+            'def print_hello_world():\n    print("Hello World!")\n\n\ndef print_hello_world_with_args(name',
+            'def say_hello():\n    print("Hello, World!")\n\n\nsay_hello()\n',
         ]
         self.assertListEqual(outputs, expected_output)
 
+    def test_newline_regression(self):
+        """Added to prevent regressions regarding attention (scaling) indicated by excessive newlines"""
+        tokenizer = AutoTokenizer.from_pretrained("bigcode/tiny_starcoder_py")
+        model = GPTBigCodeForCausalLM.from_pretrained("bigcode/tiny_starcoder_py").to(torch_device)
+
+        input_ids = tokenizer(
+            "Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.\n",
+            return_tensors="pt",
+        ).input_ids.to(torch_device)
+
+        output_sequence = model.generate(input_ids, max_new_tokens=20, do_sample=False)
+        output_sentence = tokenizer.decode(output_sequence[0], skip_special_tokens=True)
+
+        expected_output = 'Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.\n\nThe impact of the COVID-19 pandemic on global economic structures and future business'  # fmt: skip
+        self.assertEqual(output_sentence, expected_output)
+
 
 @require_torch
 class GPTBigCodeMQATest(unittest.TestCase):
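For reference, the new test_newline_regression check can be reproduced outside the test suite with a short standalone script. This is a minimal sketch, assuming the bigcode/tiny_starcoder_py checkpoint is reachable from the Hub and a plain cuda/cpu device stands in for the suite's torch_device; the triple-newline check at the end is only an illustrative heuristic for the "excessive newlines" symptom, not the exact-match assertion the test itself uses.

import torch
from transformers import AutoTokenizer, GPTBigCodeForCausalLM

# Simple stand-in for the test suite's torch_device.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("bigcode/tiny_starcoder_py")
model = GPTBigCodeForCausalLM.from_pretrained("bigcode/tiny_starcoder_py").to(device)

prompt = "Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.\n"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# Greedy decoding with the same settings as the test; the earlier attention-scaling
# regression showed up here as a run of newlines instead of a text continuation.
output_ids = model.generate(input_ids, max_new_tokens=20, do_sample=False)
completion = tokenizer.decode(output_ids[0], skip_special_tokens=True)

# Illustrative heuristic only: the expected continuation contains "\n\n" but no longer runs of newlines.
assert "\n\n\n" not in completion, completion
print(completion)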