Skip to content

Commit

Permalink
Don't use named args in MHA calls to allow applying pytorch forward h…
Browse files Browse the repository at this point in the history
…ooks to VIT (#6956)
  • Loading branch information
sovrasov committed Nov 18, 2022
1 parent d710f3d commit 5b4f79d
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchvision/models/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(
def forward(self, input: torch.Tensor):
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
x = self.ln_1(input)
x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
x, _ = self.self_attention(x, x, x, need_weights=False)
x = self.dropout(x)
x = x + input

Expand Down

0 comments on commit 5b4f79d

Please sign in to comment.