Skip to content

Commit

Permalink
fix gptj could not jit.trace in GPU (#23317)
Browse files Browse the repository at this point in the history
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
  • Loading branch information
sywangyi authored May 24, 2023
1 parent b4698b7 commit 767e6b5
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/models/gptj/modeling_gptj.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def forward(
key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)

if is_torch_fx_proxy(position_ids):
if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
# The logic to conditionally copy to GPU could not be traced, so we do this
# every time in the torch.fx case
embed_positions = get_embed_positions(self.embed_positions, position_ids)
Expand Down

0 comments on commit 767e6b5

Please sign in to comment.