Skip to content

Commit

Permalink
Correct naming pegasus x (huggingface#18896)
Browse files Browse the repository at this point in the history
* add first generation tutorial

* [Pegasus X] correct naming

* [Generation] Remove
  • Loading branch information
patrickvonplaten authored and oneraghavan committed Sep 26, 2022
1 parent cc6a1c4 commit 42e9c3a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/models/pegasus_x/test_modeling_pegasus_x.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def default_tokenizer(self):
return PegasusTokenizer.from_pretrained("google/pegasus-x-base")

def test_inference_no_head(self):
model = PegasusXModel.from_pretrained("pegasus-x-base").to(torch_device)
model = PegasusXModel.from_pretrained("google/pegasus-x-base").to(torch_device)
input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
decoder_input_ids = _long_tensor([[2, 0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588]])
inputs_dict = prepare_pegasus_x_inputs_dict(model.config, input_ids, decoder_input_ids)
Expand All @@ -574,7 +574,7 @@ def test_inference_no_head(self):
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))

def test_inference_head(self):
model = PegasusXForConditionalGeneration.from_pretrained("pegasus-x-base").to(torch_device)
model = PegasusXForConditionalGeneration.from_pretrained("google/pegasus-x-base").to(torch_device)

# change to intended input
input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
Expand Down

0 comments on commit 42e9c3a

Please sign in to comment.