Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the test checking for cooperative kernels in conditional nodes. #9869

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from nemo.collections.asr.parts.utils import rnnt_utils
from nemo.core.utils.cuda_python_utils import (
check_cuda_python_cuda_graphs_conditional_nodes_supported,
checked_graph,
cu_call,
run_nvrtc,
with_conditional_node,
Expand Down Expand Up @@ -174,7 +175,7 @@ def _reinitialize(self, max_time, batch_size, encoder_output, encoder_output_len
with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(self.graph, stream=stream_for_graph, capture_error_mode="thread_local"),
checked_graph(self.graph, stream=stream_for_graph, capture_error_mode="thread_local"),
):
# This is failing...
self.f = torch.zeros(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs
from nemo.core.utils.cuda_python_utils import (
check_cuda_python_cuda_graphs_conditional_nodes_supported,
checked_graph,
cu_call,
run_nvrtc,
with_conditional_node,
Expand Down Expand Up @@ -630,7 +631,7 @@ def _partial_graphs_compile(self):
with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(
checked_graph(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need checked_graph for partial graphs (without conditional nodes)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a very good point. No, we don't.

But I am considering doing your suggestion anyway for a fallback, in which case we wouldn't use checked_graph.

self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local"
),
):
Expand All @@ -639,7 +640,7 @@ def _partial_graphs_compile(self):
with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(
checked_graph(
self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
),
):
Expand All @@ -649,7 +650,7 @@ def _partial_graphs_compile(self):
with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(
checked_graph(
self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local"
),
):
Expand All @@ -658,7 +659,7 @@ def _partial_graphs_compile(self):
with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(
checked_graph(
self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
),
):
Expand All @@ -672,7 +673,7 @@ def _full_graph_compile(self):
with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
checked_graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
):
self._before_outer_loop()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs
from nemo.core.utils.cuda_python_utils import (
check_cuda_python_cuda_graphs_conditional_nodes_supported,
checked_graph,
cu_call,
run_nvrtc,
with_conditional_node,
Expand Down Expand Up @@ -691,7 +692,7 @@ def _partial_graphs_compile(self):
with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(
checked_graph(
self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local"
),
):
Expand All @@ -700,7 +701,7 @@ def _partial_graphs_compile(self):
with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(
checked_graph(
self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
),
):
Expand All @@ -710,7 +711,7 @@ def _partial_graphs_compile(self):
with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(
checked_graph(
self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local"
),
):
Expand All @@ -719,7 +720,7 @@ def _partial_graphs_compile(self):
with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(
checked_graph(
self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
),
):
Expand All @@ -734,7 +735,7 @@ def _full_graph_compile(self):
with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
checked_graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"),
):
self._before_outer_loop()

Expand Down
38 changes: 38 additions & 0 deletions nemo/core/utils/cuda_python_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,28 @@ def cu_call(f_call_out):
return tuple(others)


def cuda_python_conditional_node_cooperative_kernels_supported():
    """
    Return True if cuda-python is installed and CUDA driver 12.6 or newer is
    installed. Before this CUDA driver version, cooperative kernels could not
    run within cuda graph conditional nodes.

    Returns:
        bool: True when conditional nodes are usable AND the driver is >= 12.6.

    Raises:
        ImportError: if cuDriverGetVersion() itself fails after the
            conditional-node support check has already passed.
    """
    try:
        # Raises when cuda-python is missing or conditional nodes are unsupported.
        check_cuda_python_cuda_graphs_conditional_nodes_supported()
    except Exception:
        # Was a bare `except:`, which also swallowed SystemExit/KeyboardInterrupt.
        return False
    else:
        from cuda import cuda

        error, driver_version = cuda.cuDriverGetVersion()
        if error != cuda.CUresult.CUDA_SUCCESS:
            raise ImportError(f"cuDriverGetVersion() returned {cuda.cuGetErrorString(error)}")
        # Driver version is encoded as 1000*major + 10*minor (e.g. 12060 -> (12, 6)).
        driver_version_major = driver_version // 1000
        driver_version_minor = (driver_version % 1000) // 10
        return (driver_version_major, driver_version_minor) >= (12, 6)


@contextlib.contextmanager
def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditional_handle, device):
"""
Expand Down Expand Up @@ -219,3 +241,19 @@ def run_nvrtc(kernel_string: str, kernel_name: bytes, program_name: bytes):
assert_drv(err)

return kernel


@contextlib.contextmanager
def checked_graph(*args, **kwargs):
"""
Wrapper around torch.cuda.graph that checks for common errors that are too vague for an end user to diagnose based on the error message.
"""
try:
with torch.cuda.graph(*args, **kwargs):
yield
except RuntimeError as err:
if "CUDA error: invalid argument" in str(err):
raise RuntimeError(
"CUDA Graph capture failed. It is likely that you are calling a cooperative kernel in your RNN-T or TDT prediction network. Cooperative kernels are not allowed inside the bodies of CUDA Graph conditional nodes until CUDA 12.6. Please update to CUDA 12.6. File an issue if that still does not work."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only support CUDA 12.5 with published PyTorch containers; see here: https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html.

Make sure to support running without cuda_graphs decoding by default; in a later version we can make cuda_graphs on by default once the containers support it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update

__CUDA_PYTHON_MINIMUM_VERSION_CUDA_GRAPH_CONDITIONAL_NODES_SUPPORTED__ = (12, 3) # 12030
the minimum version to 12.6

Copy link
Collaborator

@artbataev artbataev Aug 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I think @galv wants to preserve 12.3 as a requirement, but fail in rare cases when cooperative kernels are selected.

I would suggest in this case applying fallback behavior like this

if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH:

(also, the same code in tdt_loop_labels_computer.py)

if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH:
    try:
        self._full_graph_compile()
    except RuntimeError:
       # fallback to graphs without while loops
       self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS
       self._partial_graphs_compile()

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's an interesting possibility.

One of the big challenges is that the error returned by torch is not very precise. It's just a RuntimeError corresponding to "invalid argument", or cudaErrorInvalidValue, which is not a precise enough error for us to tell that the problem specifically is that the code is using a cooperative kernel within a conditional node's body graph. And unfortunately we cannot check whether this is the case because conditional node API does not expose a way to get the body graph(s) of a conditional node, right now...

Anyway, I suppose if the error was not because of a cooperative kernel, but because of something else, then there is a good chance the error will get thrown by the partial graphs implementation. But it's still not a guarantee!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO sounds like it's worth a shot to be able to move forward here

) from err
raise
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from omegaconf import open_dict

from nemo.collections.asr.models import ASRModel
from nemo.core.utils.cuda_python_utils import skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported
from nemo.core.utils.cuda_python_utils import (
cuda_python_conditional_node_cooperative_kernels_supported,
skip_cuda_python_test_if_cuda_graphs_conditional_nodes_not_supported,
)


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -53,9 +56,10 @@ def stt_en_fastconformer_transducer_large():
8,
True,
marks=pytest.mark.xfail(
reason="""Cannot instantiate the
body cuda graph of a conditional node with a persistent kernel (in this case,
a persistent LSTM), which is triggered in cudnn by using a batch size of 8."""
not cuda_python_conditional_node_cooperative_kernels_supported(),
reason="""Cannot instantiate the
body cuda graph of a conditional node with a persistent kernel (in this case,
a persistent LSTM), which is triggered in cudnn by using a batch size of 8.""",
),
),
],
Expand Down
Loading