Refactor ms_deform_attn compile logic
xiuqhou committed May 16, 2024
1 parent 1b5d6f8; commit 27116b9
Showing 1 changed file with 11 additions and 7 deletions.
models/bricks/ms_deform_attn.py: 11 additions & 7 deletions
@@ -11,13 +11,17 @@
 from torch.nn.init import constant_, xavier_uniform_
 from torch.utils.cpp_extension import load
 
+_C = None
 if torch.cuda.is_available():
-    _C = load(
-        "MultiScaleDeformableAttention",
-        sources=[f"{os.path.dirname(__file__)}/ops/cuda/ms_deform_attn_cuda.cu"],
-        extra_cflags=["-O2"],
-        verbose=True,
-    )
+    try:
+        _C = load(
+            "MultiScaleDeformableAttention",
+            sources=[f"{os.path.dirname(__file__)}/ops/cuda/ms_deform_attn_cuda.cu"],
+            extra_cflags=["-O2"],
+            verbose=True,
+        )
+    except Exception as e:
+        warnings.warn(f"Failed to load MultiScaleDeformableAttention C++ extension: {e}")
 else:
     warnings.warn("No cuda is available, skip loading MultiScaleDeformableAttention C++ extention")

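Read as a whole, the hunk makes _C default to None and leaves it None whenever the JIT build raises, instead of letting a compilation error crash the import. Assembled from the diff into one runnable block (the os, warnings, and torch imports are assumed from the file's elided header):

import os
import warnings

import torch
from torch.utils.cpp_extension import load

# _C holds the JIT-compiled CUDA extension, or stays None if it cannot be built.
_C = None
if torch.cuda.is_available():
    try:
        # load() compiles the .cu source at import time; it can raise even on a
        # CUDA machine (missing nvcc, toolchain mismatch, stale build cache).
        _C = load(
            "MultiScaleDeformableAttention",
            sources=[f"{os.path.dirname(__file__)}/ops/cuda/ms_deform_attn_cuda.cu"],
            extra_cflags=["-O2"],
            verbose=True,
        )
    except Exception as e:
        warnings.warn(f"Failed to load MultiScaleDeformableAttention C++ extension: {e}")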
@@ -351,7 +355,7 @@ def forward(
         )
 
         # the original impl for fp32 training
-        if torch.cuda.is_available() and value.is_cuda:
+        if _C is not None and value.is_cuda:
             output = MultiScaleDeformableAttnFunction.apply(
                 value.to(torch.float32),
                 spatial_shapes,
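The second hunk is the consumer-side half of the refactor: gating on _C is not None means a failed build and an absent GPU both take the same non-CUDA path, whereas the old torch.cuda.is_available() test could reach for an extension that never loaded. The branch taken when the check fails sits outside this hunk; files in this lineage conventionally fall back to the pure-PyTorch reference implementation from Deformable-DETR. A sketch under that assumption (the function name and tensor layout below are the conventional ones, not taken from this diff):

from typing import List, Tuple

import torch
import torch.nn.functional as F

def multi_scale_deformable_attn_pytorch(
    value: torch.Tensor,                    # (bs, sum(H*W), n_heads, head_dim)
    spatial_shapes: List[Tuple[int, int]],  # [(H, W), ...] for each feature level
    sampling_locations: torch.Tensor,       # (bs, n_q, n_heads, n_levels, n_pts, 2), in [0, 1]
    attention_weights: torch.Tensor,        # (bs, n_q, n_heads, n_levels, n_pts)
) -> torch.Tensor:
    bs, _, n_heads, head_dim = value.shape
    _, n_q, _, n_levels, n_pts, _ = sampling_locations.shape
    value_list = value.split([h * w for h, w in spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_locations - 1  # grid_sample wants [-1, 1] coords
    sampled = []
    for level, (h, w) in enumerate(spatial_shapes):
        # (bs, H*W, n_heads, d) -> (bs * n_heads, d, H, W)
        value_l = value_list[level].flatten(2).transpose(1, 2).reshape(bs * n_heads, head_dim, h, w)
        # (bs, n_q, n_heads, n_pts, 2) -> (bs * n_heads, n_q, n_pts, 2)
        grid_l = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # bilinear sampling per head: (bs * n_heads, d, n_q, n_pts)
        sampled.append(F.grid_sample(value_l, grid_l, mode="bilinear",
                                     padding_mode="zeros", align_corners=False))
    # (bs, n_q, n_heads, n_levels, n_pts) -> (bs * n_heads, 1, n_q, n_levels * n_pts)
    attention_weights = attention_weights.transpose(1, 2).reshape(bs * n_heads, 1, n_q, n_levels * n_pts)
    # weight the sampled values, then sum over levels and points
    output = (torch.stack(sampled, dim=-2).flatten(-2) * attention_weights).sum(-1)
    return output.view(bs, n_heads * head_dim, n_q).transpose(1, 2).contiguous()

Because F.grid_sample is differentiable, this path still trains correctly without the extension; it is simply slower and more memory-hungry than the fused CUDA kernel.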
