Commit

[SDPA] Make sure attn mask creation is always done on CPU (huggingface#28400)

* [SDPA] Make sure attn mask creation is always done on CPU

* Update docker to 2.1.1

* revert test change
patrickvonplaten authored Jan 9, 2024
1 parent 5c7e11e commit 8604dd3
Showing 3 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions docker/transformers-all-latest-gpu/Dockerfile
@@ -9,9 +9,9 @@ SHELL ["sh", "-lc"]
# The following `ARG` are mainly used to specify the versions explicitly & directly in this docker file, and not meant
# to be used as arguments for docker build (so far).

-ARG PYTORCH='2.1.0'
+ARG PYTORCH='2.1.1'
# (not always a valid torch version)
-ARG INTEL_TORCH_EXT='2.1.0'
+ARG INTEL_TORCH_EXT='2.1.1'
# Example: `cu102`, `cu113`, etc.
ARG CUDA='cu118'

2 changes: 1 addition & 1 deletion docker/transformers-pytorch-gpu/Dockerfile
@@ -11,7 +11,7 @@ ARG REF=main
RUN git clone https://github.com/huggingface/transformers && cd transformers && git checkout $REF

# If set to nothing, will install the latest version
-ARG PYTORCH='2.1.0'
+ARG PYTORCH='2.1.1'
ARG TORCH_VISION=''
ARG TORCH_AUDIO=''
# Example: `cu102`, `cu113`, etc.
4 changes: 2 additions & 2 deletions src/transformers/modeling_attn_mask_utils.py
@@ -234,8 +234,8 @@ def _unmask_unattended(

# Get the index of the first non-zero value for every sample in the batch.
# In the above example, indices = [[2], [0], [1]]]
-tmp = torch.arange(attention_mask.shape[1], 0, -1, device=attention_mask.device)
-indices = torch.argmax(attention_mask * tmp, 1, keepdim=True)
+tmp = torch.arange(attention_mask.shape[1], 0, -1)
+indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)

# Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
# expanded mask will be completely unattended.
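
Note on the index trick in the hunk above: multiplying each mask row by descending weights `n, n-1, ..., 1` and taking the argmax yields the position of the first non-zero entry. The following is a minimal pure-Python sketch of that idea only, not the library code. The helper name `first_attended_index` and the sample mask are assumptions, chosen to be consistent with the `[[2], [0], [1]]` result quoted in the comment.

```python
def first_attended_index(mask_row):
    """Return the index of the first non-zero entry in one 0/1 mask row.

    Hypothetical helper for illustration; it mirrors the
    `argmax(attention_mask * tmp)` step without using torch.
    """
    n = len(mask_row)
    # Descending weights n, n-1, ..., 1 play the role of `tmp` in the diff.
    weights = range(n, 0, -1)
    scores = [m * w for m, w in zip(mask_row, weights)]
    # Plain argmax: the earliest attended position carries the largest weight.
    return max(range(n), key=scores.__getitem__)


# Assumed sample mask (not taken from the commit): left-padded rows.
attention_mask = [
    [0, 0, 1],
    [1, 1, 1],
    [0, 1, 1],
]
# keepdim=True in the diff keeps one column per row, so wrap each index.
indices = [[first_attended_index(row)] for row in attention_mask]
print(indices)  # [[2], [0], [1]]
```

In the commit itself the same computation runs on tensors, with the mask moved to CPU via `.cpu()` so that it shares a device with the `tmp` tensor, which is now created without a `device` argument.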