Skip to content

Commit

Permalink
Add safeguards for CUDA kernel load in Deformable DETR (#19037)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgugger authored Sep 14, 2022
1 parent 31be02f commit 0e24548
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,21 @@
)
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from ...utils import is_ninja_available, logging
from .configuration_deformable_detr import DeformableDetrConfig
from .load_custom import load_cuda_kernels


logger = logging.get_logger(__name__)

# Move this to not compile only when importing, this needs to happen later, like in __init__.
if is_torch_cuda_available():
if is_torch_cuda_available() and is_ninja_available():
logger.info("Loading custom CUDA kernels...")
MultiScaleDeformableAttention = load_cuda_kernels()
try:
MultiScaleDeformableAttention = load_cuda_kernels()
except Exception as e:
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
MultiScaleDeformableAttention = None
else:
MultiScaleDeformableAttention = None

Expand Down
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
is_in_notebook,
is_ipex_available,
is_librosa_available,
is_ninja_available,
is_onnx_available,
is_pandas_available,
is_phonemizer_available,
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,10 @@ def is_apex_available():
return importlib.util.find_spec("apex") is not None


def is_ninja_available():
return importlib.util.find_spec("ninja") is not None


def is_ipex_available():
def get_major_and_minor_from_version(full_version):
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
Expand Down

0 comments on commit 0e24548

Please sign in to comment.