Skip to content

Commit ce7477f

Browse files
committed
Update the test checking for cooperative kernels in conditional nodes.
The test is now conditionally xfailed only when the installed CUDA driver version is older than 12.6, because CUDA 12.6 fixes this issue. Before CUDA 12.6, cooperative kernels could not be used within the body of a conditional node. We also provide a better error message so that users know the fix is to upgrade to CUDA 12.6. Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
1 parent e400b6d commit ce7477f

File tree

5 files changed

+55
-14
lines changed

5 files changed

+55
-14
lines changed

nemo/collections/asr/parts/submodules/cuda_graph_rnnt_greedy_decoding.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from nemo.collections.asr.parts.utils import rnnt_utils
2828
from nemo.core.utils.cuda_python_utils import (
2929
check_cuda_python_cuda_graphs_conditional_nodes_supported,
30+
checked_graph,
3031
cu_call,
3132
run_nvrtc,
3233
with_conditional_node,
@@ -174,7 +175,7 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
174175
with (
175176
torch.cuda.stream(stream_for_graph),
176177
torch.inference_mode(),
177-
torch.cuda.graph(self.graph, stream=stream_for_graph, capture_error_mode="thread_local"),
178+
checked_graph(self.graph, stream=stream_for_graph, capture_error_mode="thread_local"),
178179
):
179180
# This is failing...
180181
self.f = torch.zeros(

nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin
2525
from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs
2626
from nemo.core.utils.cuda_python_utils import (
27+
checked_graph,
2728
check_cuda_python_cuda_graphs_conditional_nodes_supported,
2829
cu_call,
2930
run_nvrtc,
@@ -630,7 +631,7 @@ def _partial_graphs_compile(self):
630631
with (
631632
torch.cuda.stream(stream_for_graph),
632633
torch.inference_mode(),
633-
torch.cuda.graph(
634+
checked_graph(
634635
self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local"
635636
),
636637
):
@@ -639,7 +640,7 @@ def _partial_graphs_compile(self):
639640
with (
640641
torch.cuda.stream(stream_for_graph),
641642
torch.inference_mode(),
642-
torch.cuda.graph(
643+
checked_graph(
643644
self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
644645
),
645646
):
@@ -649,7 +650,7 @@ def _partial_graphs_compile(self):
649650
with (
650651
torch.cuda.stream(stream_for_graph),
651652
torch.inference_mode(),
652-
torch.cuda.graph(
653+
checked_graph(
653654
self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local"
654655
),
655656
):
@@ -658,7 +659,7 @@ def _partial_graphs_compile(self):
658659
with (
659660
torch.cuda.stream(stream_for_graph),
660661
torch.inference_mode(),
661-
torch.cuda.graph(
662+
checked_graph(
662663
self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
663664
),
664665
):
@@ -672,7 +673,7 @@ def _full_graph_compile(self):
672673
with (
673674
torch.cuda.stream(stream_for_graph),
674675
torch.inference_mode(),
675-
torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
676+
checked_graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
676677
):
677678
self._before_outer_loop()
678679

nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin
2626
from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs
2727
from nemo.core.utils.cuda_python_utils import (
28+
checked_graph,
2829
check_cuda_python_cuda_graphs_conditional_nodes_supported,
2930
cu_call,
3031
run_nvrtc,
@@ -691,7 +692,7 @@ def _partial_graphs_compile(self):
691692
with (
692693
torch.cuda.stream(stream_for_graph),
693694
torch.inference_mode(),
694-
torch.cuda.graph(
695+
checked_graph(
695696
self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local"
696697
),
697698
):
@@ -700,7 +701,7 @@ def _partial_graphs_compile(self):
700701
with (
701702
torch.cuda.stream(stream_for_graph),
702703
torch.inference_mode(),
703-
torch.cuda.graph(
704+
checked_graph(
704705
self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
705706
),
706707
):
@@ -710,7 +711,7 @@ def _partial_graphs_compile(self):
710711
with (
711712
torch.cuda.stream(stream_for_graph),
712713
torch.inference_mode(),
713-
torch.cuda.graph(
714+
checked_graph(
714715
self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local"
715716
),
716717
):
@@ -719,7 +720,7 @@ def _partial_graphs_compile(self):
719720
with (
720721
torch.cuda.stream(stream_for_graph),
721722
torch.inference_mode(),
722-
torch.cuda.graph(
723+
checked_graph(
723724
self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
724725
),
725726
):
@@ -734,7 +735,7 @@ def _full_graph_compile(self):
734735
with (
735736
torch.cuda.stream(stream_for_graph),
736737
torch.inference_mode(),
737-
torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
738+
checked_graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
738739
):
739740
self._before_outer_loop()
740741

nemo/core/utils/cuda_python_utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,27 @@ def cu_call(f_call_out):
9595
return tuple(others)
9696

9797

98+
def cuda_python_conditional_node_cooperative_kernels_supported():
    """
    Report whether cooperative kernels can run inside CUDA graph conditional nodes.

    Returns:
        bool: True when ``check_cuda_python_cuda_graphs_conditional_nodes_supported()``
        succeeds and the installed CUDA driver reports version 12.6 or newer.
        False when that check raises or when the driver is older. Before CUDA
        driver 12.6, cooperative kernels could not run within the body of a
        CUDA graph conditional node.

    Raises:
        ImportError: If ``cuDriverGetVersion()`` does not return ``CUDA_SUCCESS``.
    """
    try:
        check_cuda_python_cuda_graphs_conditional_nodes_supported()
    except Exception:
        # Catch Exception rather than using a bare ``except`` so that
        # KeyboardInterrupt and SystemExit still propagate to the caller.
        return False

    # Imported lazily: reaching this point means the support check above passed.
    from cuda import cuda

    error, raw_version = cuda.cuDriverGetVersion()
    if error != cuda.CUresult.CUDA_SUCCESS:
        raise ImportError(f"cuDriverGetVersion() returned {cuda.cuGetErrorString(error)}")

    # The driver reports its version as 1000 * major + 10 * minor (e.g. 12060 for 12.6).
    major, remainder = divmod(raw_version, 1000)
    minor = remainder // 10
    return (major, minor) >= (12, 6)
118+
98119
@contextlib.contextmanager
99120
def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditional_handle, device):
100121
"""
@@ -219,3 +240,16 @@ def run_nvrtc(kernel_string: str, kernel_name: bytes, program_name: bytes):
219240
assert_drv(err)
220241

221242
return kernel
243+
244+
@contextlib.contextmanager
def checked_graph(*args, **kwargs):
    """
    Context manager that forwards all arguments to ``torch.cuda.graph`` and
    replaces one vague failure with an actionable error message.

    A ``RuntimeError`` whose text contains "CUDA error: invalid argument" is
    re-raised as a new ``RuntimeError`` that points the user at the CUDA 12.6
    cooperative-kernel fix, chained to the original error. Any other
    ``RuntimeError`` is re-raised unchanged.
    """
    try:
        with torch.cuda.graph(*args, **kwargs):
            yield
    except RuntimeError as capture_error:
        # Anything other than the known vague message is passed through untouched.
        if "CUDA error: invalid argument" not in str(capture_error):
            raise
        raise RuntimeError(
            "CUDA Graph capture failed. It is likely that you are calling a cooperative kernel in your RNN-T or TDT prediction network. "
            "Cooperative kernels are not allowed inside the bodies of CUDA Graph conditional nodes until CUDA 12.6. "
            "Please update to CUDA 12.6. File an issue if that still does not work."
        ) from capture_error

tests/collections/asr/decoding/test_cuda_graph_rnnt_greedy_decoding.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
from omegaconf import open_dict
2121

2222
from nemo.collections.asr.models import ASRModel
23-
from nemo.core.utils.cuda_python_utils import skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported
23+
from nemo.core.utils.cuda_python_utils import (
24+
skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported,
25+
cuda_python_conditional_node_cooperative_kernels_supported
26+
)
2427

2528

2629
@pytest.fixture(scope="module")
@@ -53,8 +56,9 @@ def stt_en_fastconformer_transducer_large():
5356
8,
5457
True,
5558
marks=pytest.mark.xfail(
56-
reason="""Cannot instantiate the
57-
body cuda graph of a conditional node with a persistent kernel (in this case,
59+
not cuda_python_conditional_node_cooperative_kernels_supported(),
60+
reason="""Cannot instantiate the
61+
body cuda graph of a conditional node with a persistent kernel (in this case,
5862
a persistent LSTM), which is triggered in cudnn by using a batch size of 8."""
5963
),
6064
),

0 commit comments

Comments
 (0)