Add an attribute to disable custom kernels in deformable detr in order to make the model ONNX exportable (#22918)

* add disable kernel option

* add comment

* fix copies

* add disable_custom_kernels to config

* Update src/transformers/models/deta/modeling_deta.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/deta/modeling_deta.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/deta/modeling_deta.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* style

* fix

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
fxmarty and amyeroberts authored Apr 24, 2023
1 parent 84097f6 commit edb6d95
Showing 3 changed files with 35 additions and 30 deletions.
src/transformers/models/deformable_detr/configuration_deformable_detr.py
@@ -125,6 +125,9 @@ class DeformableDetrConfig(PretrainedConfig):
             based on the predictions from the previous layer.
         focal_alpha (`float`, *optional*, defaults to 0.25):
             Alpha parameter in the focal loss.
+        disable_custom_kernels (`bool`, *optional*, defaults to `False`):
+            Disable the use of custom CUDA and CPU kernels. This option is necessary for the ONNX export, as custom
+            kernels are not supported by PyTorch ONNX export.
 
     Examples:
@@ -189,6 +192,7 @@ def __init__(
         giou_loss_coefficient=2,
         eos_coefficient=0.1,
         focal_alpha=0.25,
+        disable_custom_kernels=False,
         **kwargs,
     ):
         if backbone_config is not None and use_timm_backbone:
@@ -246,6 +250,7 @@ def __init__(
         self.giou_loss_coefficient = giou_loss_coefficient
         self.eos_coefficient = eos_coefficient
         self.focal_alpha = focal_alpha
+        self.disable_custom_kernels = disable_custom_kernels
         super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)
 
     @property
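A minimal usage sketch (not part of the diff): the checkpoint name below is an assumption, only the `disable_custom_kernels` kwarg comes from this commit. Config kwargs passed to `from_pretrained` are forwarded to the config, so the flag can be set directly when loading:

from transformers import DeformableDetrForObjectDetection

# Assumed checkpoint; any Deformable DETR checkpoint should behave the same way.
# Setting disable_custom_kernels=True makes the attention layers skip the
# compiled CUDA/CPU kernels and use the pure-PyTorch implementation instead.
model = DeformableDetrForObjectDetection.from_pretrained(
    "SenseTime/deformable-detr", disable_custom_kernels=True
)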
57 changes: 30 additions & 27 deletions src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -589,13 +589,13 @@ class DeformableDetrMultiscaleDeformableAttention(nn.Module):
     Multiscale deformable attention as proposed in Deformable DETR.
     """
 
-    def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int):
+    def __init__(self, config: DeformableDetrConfig, num_heads: int, n_points: int):
         super().__init__()
-        if embed_dim % num_heads != 0:
+        if config.d_model % num_heads != 0:
             raise ValueError(
-                f"embed_dim (d_model) must be divisible by num_heads, but got {embed_dim} and {num_heads}"
+                f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
             )
-        dim_per_head = embed_dim // num_heads
+        dim_per_head = config.d_model // num_heads
         # check if dim_per_head is power of 2
         if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
             warnings.warn(
@@ -606,15 +606,17 @@ def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int):

         self.im2col_step = 64
 
-        self.d_model = embed_dim
-        self.n_levels = n_levels
+        self.d_model = config.d_model
+        self.n_levels = config.num_feature_levels
         self.n_heads = num_heads
         self.n_points = n_points
 
-        self.sampling_offsets = nn.Linear(embed_dim, num_heads * n_levels * n_points * 2)
-        self.attention_weights = nn.Linear(embed_dim, num_heads * n_levels * n_points)
-        self.value_proj = nn.Linear(embed_dim, embed_dim)
-        self.output_proj = nn.Linear(embed_dim, embed_dim)
+        self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
+        self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
+        self.value_proj = nn.Linear(config.d_model, config.d_model)
+        self.output_proj = nn.Linear(config.d_model, config.d_model)
+
+        self.disable_custom_kernels = config.disable_custom_kernels
 
         self._reset_parameters()
 
@@ -692,19 +694,24 @@ def forward(
             )
         else:
             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
-        try:
-            # custom kernel
-            output = MultiScaleDeformableAttentionFunction.apply(
-                value,
-                spatial_shapes,
-                level_start_index,
-                sampling_locations,
-                attention_weights,
-                self.im2col_step,
-            )
-        except Exception:
-            # PyTorch implementation
-            output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
+
+        if self.disable_custom_kernels:
+            # PyTorch implementation
+            output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
+        else:
+            try:
+                # custom kernel
+                output = MultiScaleDeformableAttentionFunction.apply(
+                    value,
+                    spatial_shapes,
+                    level_start_index,
+                    sampling_locations,
+                    attention_weights,
+                    self.im2col_step,
+                )
+            except Exception:
+                # PyTorch implementation
+                output = multi_scale_deformable_attention(value, spatial_shapes, sampling_locations, attention_weights)
         output = self.output_proj(output)
 
         return output, attention_weights
@@ -832,10 +839,7 @@ def __init__(self, config: DeformableDetrConfig):
         super().__init__()
         self.embed_dim = config.d_model
         self.self_attn = DeformableDetrMultiscaleDeformableAttention(
-            embed_dim=self.embed_dim,
-            num_heads=config.encoder_attention_heads,
-            n_levels=config.num_feature_levels,
-            n_points=config.encoder_n_points,
+            config, num_heads=config.encoder_attention_heads, n_points=config.encoder_n_points
         )
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
         self.dropout = config.dropout
@@ -933,9 +937,8 @@ def __init__(self, config: DeformableDetrConfig):
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
         # cross-attention
         self.encoder_attn = DeformableDetrMultiscaleDeformableAttention(
-            embed_dim=self.embed_dim,
+            config,
             num_heads=config.decoder_attention_heads,
-            n_levels=config.num_feature_levels,
             n_points=config.decoder_n_points,
         )
         self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
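With the custom kernels disabled, the model can go through the PyTorch ONNX exporter. The sketch below is untested; the input resolution, opset version, output names, and file name are assumptions, not part of this commit:

import torch
from transformers import DeformableDetrConfig, DeformableDetrForObjectDetection

# disable_custom_kernels=True forces the pure-PyTorch
# multi_scale_deformable_attention path, which torch.onnx.export can trace;
# return_dict=False makes the model return a plain tuple for the exporter.
config = DeformableDetrConfig(disable_custom_kernels=True, return_dict=False)
model = DeformableDetrForObjectDetection(config).eval()

dummy_pixel_values = torch.randn(1, 3, 800, 800)  # assumed input size

torch.onnx.export(
    model,
    (dummy_pixel_values,),
    "deformable_detr.onnx",
    input_names=["pixel_values"],
    output_names=["logits", "pred_boxes"],
    opset_version=16,  # assumed; use the opset your runtime supports
)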
3 changes: 0 additions & 3 deletions src/transformers/models/deta/modeling_deta.py
@@ -492,7 +492,6 @@ class DetaMultiscaleDeformableAttention(nn.Module):
     Multiscale deformable attention as proposed in Deformable DETR.
     """
 
-    # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention.__init__ with DeformableDetr->Deta
     def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int):
         super().__init__()
         if embed_dim % num_heads != 0:
@@ -721,7 +720,6 @@ def forward(
         return attn_output, attn_weights_reshaped
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrEncoderLayer with DeformableDetr->Deta
 class DetaEncoderLayer(nn.Module):
     def __init__(self, config: DetaConfig):
         super().__init__()
@@ -810,7 +808,6 @@ def forward(
         return outputs
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrDecoderLayer with DeformableDetr->Deta
 class DetaDecoderLayer(nn.Module):
     def __init__(self, config: DetaConfig):
         super().__init__()
