
Commit

Update classifier.py
fffffgggg54 committed Jul 31, 2024
1 parent 8e332b3 commit 8aedf0d
Showing 1 changed file with 28 additions and 13 deletions.
timm/layers/classifier.py — 41 changes: 28 additions & 13 deletions
@@ -7,10 +7,12 @@
 from typing import Optional, Union, Callable
 
 import torch
+from torch.jit import Final
 import torch.nn as nn
 from torch.nn import functional as F
 
 from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from .config import use_fused_attn
 from .create_act import get_act_layer, create_act_layer
 from .create_norm import get_norm_layer, create_norm_layer
 from .drop import DropPath
@@ -275,6 +277,7 @@ def forward(self, x):
         return self.head(x)
 
 class Attention(nn.Module):
+    fused_attn: Final[bool]
     def __init__(
             self,
             dim,
@@ -290,6 +293,7 @@ def __init__(
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.scale = self.head_dim ** -0.5
+        self.fused_attn = use_fused_attn()
 
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -300,12 +304,17 @@ def forward(self, x):
         B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)
-
-
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-        x = attn @ v
+
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p=self.attn_drop.p if self.training else 0.,
+            )
+        else:
+            attn = (q @ k.transpose(-2, -1)) * self.scale
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
 
         x = x.transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
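The fused and manual branches above are meant to be interchangeable when dropout is inactive: F.scaled_dot_product_attention applies the same 1/sqrt(head_dim) scaling internally that the else branch applies via self.scale. A minimal sketch of that equivalence (not part of the commit; the shapes are hypothetical and torch >= 2.0 is assumed for F.scaled_dot_product_attention):

import torch
import torch.nn.functional as F

# Hypothetical shapes: batch 2, 4 heads, 16 tokens, head_dim 8.
q, k, v = (torch.randn(2, 4, 16, 8) for _ in range(3))
scale = q.shape[-1] ** -0.5  # same definition as self.scale = head_dim ** -0.5

# Manual branch (dropout omitted, i.e. eval mode).
attn = (q @ k.transpose(-2, -1)) * scale
manual = attn.softmax(dim=-1) @ v

# Fused branch; SDPA scales by 1/sqrt(head_dim) on its own.
fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.)

print(torch.allclose(manual, fused, atol=1e-5))  # expected: True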
@@ -353,7 +362,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 class CrossAttention(nn.Module):
-
+    fused_attn: Final[bool]
     def __init__(
             self,
             dim: int,
@@ -375,6 +384,7 @@ def __init__(
         self.dim = dim
         self.query_dim = self.dim if query_dim is None else query_dim
         self.kv_dim = self.dim if kv_dim is None else kv_dim
+        self.fused_attn = use_fused_attn()
 
         self.q = nn.Linear(self.query_dim, self.dim, bias=qkv_bias)
         self.kv = nn.Linear(self.kv_dim, self.dim * 2, bias=qkv_bias)
@@ -392,12 +402,17 @@ def forward(self, q, x) -> torch.Tensor:
         k, v = kv.unbind(0)
         q, k = self.q_norm(q), self.k_norm(k)
 
-
-        q = q * self.scale
-        attn = q @ k.transpose(-2, -1)  # [B, n_h, N_q, N_kv]
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-        x = attn @ v  # [B, n_h, N_q, d_h]
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                dropout_p=self.attn_drop.p if self.training else 0.,
+            )
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)  # [B, n_h, N_q, N_kv]
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v  # [B, n_h, N_q, d_h]
 
         x = x.permute(0, 2, 1, 3).reshape(B, N_q, self.dim)
         x = self.proj(x)
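Note that the manual CrossAttention branch folds the scaling into q (q = q * self.scale) before the matmul, while the fused branch passes q unscaled and relies on F.scaled_dot_product_attention's internal scaling; the two agree as long as self.scale is the usual head_dim ** -0.5. A sketch under that assumption, with illustrative query/key lengths not taken from the commit:

import torch
import torch.nn.functional as F

# Hypothetical cross-attention shapes: batch 2, 4 heads, 3 queries, 7 key/value tokens, head_dim 8.
q = torch.randn(2, 4, 3, 8)
k = torch.randn(2, 4, 7, 8)
v = torch.randn(2, 4, 7, 8)
scale = q.shape[-1] ** -0.5  # assumed to match self.scale

# Manual branch: pre-scale q, softmax over the key axis.
manual = ((q * scale) @ k.transpose(-2, -1)).softmax(dim=-1) @ v  # [B, n_h, N_q, d_h]

# Fused branch: no explicit scaling needed.
fused = F.scaled_dot_product_attention(q, k, v, dropout_p=0.)

print(torch.allclose(manual, fused, atol=1e-5))  # expected: True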
