Skip to content

Commit

Permalink
Construct EarlyFusion's encoder_token_ids on correct device (pytorch#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ebsmothers authored Jan 17, 2025
1 parent 890deab commit 7747db1
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion torchtune/modules/model_fusion/_early_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,9 @@ def reset_caches(self):

def _decoder_embed(self, tokens) -> Tuple[torch.Tensor, torch.Tensor]:
"""Embed the text-only tokens with the decoder's tok_embeddings"""
encoder_token_ids = torch.tensor(list(self.encoder_tokens.values()))
encoder_token_ids = torch.tensor(
list(self.encoder_tokens.values()), device=tokens.device
)
# [bsz, seq_len], True indicates the token is not an encoder special token
is_text = ~torch.isin(tokens, encoder_token_ids)
text_tokens = torch.masked_select(tokens, is_text)
Expand Down

0 comments on commit 7747db1

Please sign in to comment.