Compute the mask in-place, with less memory reads, and on CUDA on `XLNetLMHeadModel` (huggingface#23332)

When working on TorchInductor, I realised that part of `XLNetLMHeadModel` was
being compiled to CPU code.

This PR should allow fusing this operation with other CUDA operations in
`torch.compile`. It should also be faster in eager mode, as this
implementation has a lower memory footprint.
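
For context, a minimal sketch of the setting that motivated the change: compiling the model with `torch.compile` on a CUDA device so the mask construction can be fused with the surrounding GPU work. The checkpoint name and the inference snippet are illustrative only and not part of this PR:

```python
import torch
from transformers import XLNetLMHeadModel, XLNetTokenizer

# Illustrative checkpoint; any XLNet LM-head checkpoint works the same way.
tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased").eval().to("cuda")

# TorchInductor compiles the forward pass; the intent of this PR is that the
# attention-mask construction is created directly on the model's device and
# can be fused with the other CUDA operations instead of running on CPU.
compiled_model = torch.compile(model)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = compiled_model(**inputs)
```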

Even if in-place operations are not allowed in a no-grad context, I still
believe that doing `ones + tril` rather than `ones + tril + zeros + cat`
should be faster, simply due to the number of memory reads/writes.
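
A small sketch (for the `same_length=False` case, with arbitrary sizes) of why the single-buffer version touches less memory: the old path materialises `ones`, `triu`, `zeros`, and `cat` intermediates, while the new path writes a single buffer in place:

```python
import torch

qlen, mlen = 4, 3  # arbitrary small sizes for illustration

# Old construction: several intermediate tensors, each read/written in full.
attn_mask = torch.ones([qlen, qlen])
mask_up = torch.triu(attn_mask, diagonal=1)
attn_mask_pad = torch.zeros([qlen, mlen])
old = torch.cat([attn_mask_pad, mask_up], dim=1)

# New construction: one allocation; triu_ keeps only the entries at least
# mlen + 1 above the main diagonal, modifying the buffer in place.
new = torch.ones(qlen, qlen + mlen)
new.triu_(mlen + 1)

assert torch.equal(old, new)
```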

I tested that this code produces the same results for `0 <= qlen, mlen < 10` and `same_length in (True, False)`.
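
A sketch of that check under stated assumptions: `old_create_mask` and `new_create_mask` below are free-standing, CPU-only re-creations of the two implementations (the device transfer is omitted), not the exact methods in `modeling_xlnet.py`:

```python
import torch

def old_create_mask(qlen, mlen, same_length):
    # Previous implementation: ones + triu + zeros + cat (+ tril/cat when same_length).
    attn_mask = torch.ones([qlen, qlen])
    mask_up = torch.triu(attn_mask, diagonal=1)
    attn_mask_pad = torch.zeros([qlen, mlen])
    ret = torch.cat([attn_mask_pad, mask_up], dim=1)
    if same_length:
        mask_lo = torch.tril(attn_mask, diagonal=-1)
        ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
    return ret

def new_create_mask(qlen, mlen, same_length):
    # New implementation: one allocation, in-place triu_.
    mask = torch.ones(qlen, qlen + mlen)
    if same_length:
        mask_lo = mask[:, :qlen].tril(-1)
        mask.triu_(mlen + 1)
        mask[:, :qlen] += mask_lo
    else:
        mask.triu_(mlen + 1)
    return mask

# Sweep the ranges mentioned above and check both variants agree exactly.
for qlen in range(10):
    for mlen in range(10):
        for same_length in (True, False):
            assert torch.equal(old_create_mask(qlen, mlen, same_length),
                               new_create_mask(qlen, mlen, same_length))
```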
lezcano authored and sheonhan committed May 15, 2023
1 parent 396e519 commit 6602192
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions src/transformers/models/xlnet/modeling_xlnet.py
@@ -976,16 +976,15 @@ def create_mask(self, qlen, mlen):
                v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0]
         """
-        attn_mask = torch.ones([qlen, qlen])
-        mask_up = torch.triu(attn_mask, diagonal=1)
-        attn_mask_pad = torch.zeros([qlen, mlen])
-        ret = torch.cat([attn_mask_pad, mask_up], dim=1)
+        mask = torch.ones(qlen, qlen + mlen, self.device)
         if self.same_length:
-            mask_lo = torch.tril(attn_mask, diagonal=-1)
-            ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
+            mask_lo = mask[:, :qlen].tril(-1)
+            mask.triu_(mlen + 1)
+            mask[:, :qlen] += mask_lo
+        else:
+            mask.triu_(mlen + 1)

-        ret = ret.to(self.device)
-        return ret
+        return mask

     def cache_mem(self, curr_out, prev_mem):
         # cache hidden states into memory.
