Skip to content

Commit

Permalink
Move test under integration because it relies on T5
Browse files — browse the repository at this point in the history
  • Loading branch information
joecummings committed Feb 17, 2023
1 parent 4bc5321 commit 2658bef
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def setUp(self) -> None:
def test_greedy_generate_with_t5(self) -> None:
generation_model = GenerationUtils(self.model)

tokens = generation_model.generate(self.inputs, num_beams=1, max_len=30)
tokens = generation_model.generate(self.inputs, num_beams=1, max_length=30)
generated_text = self.transform.decode(tokens.tolist())

expected_generated_text = [
Expand Down
4 changes: 2 additions & 2 deletions torchtext/prototype/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class GenerationUtils:

def __init__(self, model: nn.Module, **kwargs) -> None:
self.model = model
self.is_encoder_decoder = kwargs.pop("is_encoder_decoder")
self.is_huggingface_model = kwargs.pop("is_huggingface_model")
self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)
self.is_huggingface_model = kwargs.pop("is_huggingface_model", False)

def _prepare_decoder_ids_for_generation(
self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, **model_kwargs
Expand Down

0 comments on commit 2658bef

Please sign in to comment.