Skip to content

Commit

Permalink
decoder_input_ids -> forced_bos_token_id
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Mar 5, 2021
1 parent 0e25339 commit 5eefa4b
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
5 changes: 1 addition & 4 deletions src/transformers/models/mbart/tokenization_mbart50_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,10 +270,7 @@ def _build_translation_inputs(self, raw_inputs, src_lang: Optional[str], tgt_lan
self.src_lang = src_lang
inputs = self(raw_inputs, add_special_tokens=True, return_tensors="pt", truncation=truncation)
tgt_lang_id = self.convert_tokens_to_ids(tgt_lang)
import torch

inputs["decoder_input_ids"] = torch.LongTensor([self.bos_token_id, tgt_lang_id])
print(inputs)
inputs["forced_bos_token_id"] = tgt_lang_id
return inputs

def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
Expand Down
5 changes: 4 additions & 1 deletion src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,10 @@ def ensure_tensor_on_device(self, **inputs):
Return:
:obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
"""
return {name: tensor.to(self.device) for name, tensor in inputs.items()}
return {
name: tensor.to(self.device) if isinstance(tensor, torch.Tensor) else tensor
for name, tensor in inputs.items()
}

def check_model_type(self, supported_models: Union[List[str], dict]):
"""
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/pipelines/text2text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ def _generate(
max_length = generate_kwargs.get("max_length", self.model.config.max_length)
self.check_inputs(input_length, min_length, max_length)

generate_kwargs.update(inputs)

generations = self.model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
**generate_kwargs,
)
results = []
Expand Down
8 changes: 4 additions & 4 deletions tests/test_pipelines_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ def test_multilingual_translation(self):
translator("This is a test")

outputs = translator("This is a test", src_lang="en_XX", tgt_lang="ar_AR")
self.assertEqual(outputs, [{"translation_text": "یہ ایک امتحان ہے"}])
self.assertEqual(outputs, [{"translation_text": "هذا إختبار"}])

outputs = translator("This is a test", src_lang="en_XX", tgt_lang="hi_IN")
self.assertEqual(outputs, [{"translation_text": "This is a test."}])
self.assertEqual(outputs, [{"translation_text": "यह एक परीक्षण है"}])

# src_lang, tgt_lang can be defined at call time
# src_lang, tgt_lang can be defined at pipeline call time
translator = pipeline(task="translation", model=model, tokenizer=tokenizer, src_lang="en_XX", tgt_lang="ar_AR")
outputs = translator("This is a test")
self.assertEqual(outputs, [{"translation_text": "یہ ایک امتحان ہے"}])
self.assertEqual(outputs, [{"translation_text": "هذا إختبار"}])

@require_torch
def test_translation_on_odd_language(self):
Expand Down

0 comments on commit 5eefa4b

Please sign in to comment.