Commit
Compute the mask in-place, with fewer memory reads, and on CUDA in `XLNetLMHeadModel` (huggingface#23332)

While working on TorchInductor, I realised that a part of `XLNetLMHeadModel` was being compiled to CPU code. This PR should allow this operation to be fused with other CUDA operations in `torch.compile`. It should also be faster in eager mode, as this implementation has a lower memory footprint. Even if in-place operations turn out not to be allowed in non-grad contexts, I still believe that doing `ones + tril` rather than `ones + tril + zeros + cat` should be faster, simply due to the number of memory reads/writes. I tested that this code produces the same results for `0 <= qlen, mlen < 10` and `same_length in (True, False)`.
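
For reference, a minimal sketch of the kind of rewrite described above, assuming a standalone helper shaped like XLNet's `create_mask(qlen, mlen)`; the function names, signatures, and exact device handling here are illustrative, not the literal diff in the PR:

```python
import torch


def create_mask_old(qlen, mlen, same_length=False, device="cpu"):
    # Old pattern: build pieces (ones, triu, zeros), concatenate, then move to the device.
    attn_mask = torch.ones(qlen, qlen)
    mask_up = torch.triu(attn_mask, diagonal=1)
    attn_mask_pad = torch.zeros(qlen, mlen)
    ret = torch.cat([attn_mask_pad, mask_up], dim=1)
    if same_length:
        mask_lo = torch.tril(attn_mask, diagonal=-1)
        ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
    return ret.to(device)


def create_mask_new(qlen, mlen, same_length=False, device="cpu"):
    # New pattern: allocate once on the target device and fill it in-place,
    # avoiding the extra zeros tensor and the cat.
    mask = torch.ones(qlen, qlen + mlen, device=device)
    if same_length:
        mask_lo = mask[:, :qlen].tril(-1)  # strict lower triangle of the query block
        mask.triu_(mlen + 1)               # mask only positions more than mlen ahead
        mask[:, :qlen] += mask_lo          # also mask tokens too far in the past
    else:
        mask.triu_(mlen + 1)
    return mask


# Equivalence check over the range mentioned in the commit message.
for qlen in range(10):
    for mlen in range(10):
        for same_length in (True, False):
            assert torch.equal(
                create_mask_old(qlen, mlen, same_length),
                create_mask_new(qlen, mlen, same_length),
            )
```

The point of the new shape is that the mask is allocated directly on the target device and mutated with `triu_`, so under `torch.compile` nothing has to be materialised on the CPU and concatenated, and in eager mode fewer intermediate tensors are written and read.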