diff --git a/torchtext/prototype/models/t5/model.py b/torchtext/prototype/models/t5/model.py
index 7db6431d06..127aa44ee2 100644
--- a/torchtext/prototype/models/t5/model.py
+++ b/torchtext/prototype/models/t5/model.py
@@ -85,7 +85,6 @@ def __init__(
         self.padding_idx = config.padding_idx
         self.training = config.training
         self.dropout = config.dropout if config.training else 0.0
-        self.device = device
         self.dtype = dtype
 
         self.token_embeddings = nn.Embedding(config.vocab_size, config.embedding_dim, config.padding_idx)
@@ -184,13 +183,16 @@ def forward(
 
         # decoder_tokens is None means at start of inference, in which case decoder sequence should begin with padding idx.
         if decoder_tokens is None:
-            decoder_tokens = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long) * self.padding_idx
+            decoder_tokens = (
+                torch.ones((encoder_tokens.size(0), 1), device=encoder_tokens.device, dtype=torch.long)
+                * self.padding_idx
+            )
 
         if decoder_mask is None:
             assert decoder_tokens is not None and decoder_tokens.dim() == 2
             tgt_len = decoder_tokens.shape[1]
             decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1)
-            decoder_mask = decoder_mask.to(self.device, dtype=torch.bool)
+            decoder_mask = decoder_mask.to(decoder_tokens.device, dtype=torch.bool)
 
         decoder_padding_mask = decoder_tokens.eq(self.padding_idx)
         # T5 implemention uses padding idx to start sequence. Want to ignore this when masking
diff --git a/torchtext/prototype/models/t5/modules.py b/torchtext/prototype/models/t5/modules.py
index 63ec17170d..7053f42cf5 100644
--- a/torchtext/prototype/models/t5/modules.py
+++ b/torchtext/prototype/models/t5/modules.py
@@ -74,8 +74,6 @@ def __init__(
         else:
             self.relative_attention_bias = None
 
-        self.device = device
-
     def forward(
         self,
         query: Tensor,
@@ -257,9 +255,7 @@ def _t5_multi_head_attention_forward(
             ).unsqueeze(0)
         else:
             position_bias = self._compute_bias(
-                tgt_len,
-                src_len,
-                bidirectional=(not self.is_decoder),
+                tgt_len, src_len, bidirectional=(not self.is_decoder), device=k.device
             )
 
         # Calculate attention and out projection
@@ -405,15 +401,12 @@ def _t5_dot_product_attention(
 
     # NOTE: Modified from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L421
     def _compute_bias(
-        self,
-        query_length: int,
-        key_length: int,
-        bidirectional: bool = True,
+        self, query_length: int, key_length: int, bidirectional: bool = True, device: Optional[torch.device] = None
     ) -> Tensor:
         """Compute binned relative position bias"""
         assert self.relative_attention_bias is not None
-        context_position = torch.arange(query_length, dtype=torch.long, device=self.device)[:, None]
-        memory_position = torch.arange(key_length, dtype=torch.long, device=self.device)[None, :]
+        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
+        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
         relative_position = memory_position - context_position  # shape (query_length, key_length)
         relative_position_bucket = self._relative_position_bucket(
             relative_position,  # shape (query_length, key_length)
@@ -446,7 +439,7 @@ def _relative_position_bucket(
         Returns:
             a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
         """
-        relative_buckets = torch.zeros(relative_position.shape, dtype=torch.long, device=self.device)
+        relative_buckets = torch.zeros(relative_position.shape, dtype=torch.long, device=relative_position.device)
         if bidirectional:
             num_buckets = num_buckets // 2
             relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
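
Note: the sketch below is not part of the patch; the class name `RelativeBias` and its shapes are invented for illustration. It shows the pattern the diff adopts: instead of caching `self.device` at construction time, the device is read from the input tensors at call time (as in `device=k.device` and `encoder_tokens.device` above), so the module keeps producing tensors on the right device after `model.to(...)` or `model.cuda()` moves its parameters.

```python
import torch
import torch.nn as nn


class RelativeBias(nn.Module):
    """Minimal sketch (hypothetical, not from this PR) of deriving the device
    from inputs rather than storing it on the module."""

    def __init__(self, num_buckets: int = 32, num_heads: int = 8) -> None:
        super().__init__()
        # The embedding's parameters follow the module when it is moved;
        # no device is cached here.
        self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
        self.num_buckets = num_buckets

    def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
        # Take the device from an input tensor, mirroring `device=k.device` in the diff.
        device = key.device
        context_position = torch.arange(query.size(1), dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key.size(1), dtype=torch.long, device=device)[None, :]
        # Toy bucketing: clamp raw relative positions into the embedding range.
        relative_position = (memory_position - context_position).clamp(0, self.num_buckets - 1)
        return self.relative_attention_bias(relative_position)  # (tgt_len, src_len, num_heads)


# Usage: the same instance works on CPU and, if available, on GPU,
# because no device was frozen in at construction time.
bias = RelativeBias()
q = torch.randn(2, 5, 16)
k = torch.randn(2, 7, 16)
print(bias(q, k).shape)  # torch.Size([5, 7, 8])
if torch.cuda.is_available():
    bias = bias.cuda()
    print(bias(q.cuda(), k.cuda()).device)
```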