Update the test checking for cooperative kernels in conditional nodes. #9869
@@ -95,6 +95,28 @@ def cu_call(f_call_out):

```python
    return tuple(others)


def cuda_python_conditional_node_cooperative_kernels_supported():
    """
    Returns True if cuda-python is installed and CUDA driver 12.6 or newer is
    installed. Before this CUDA driver version, cooperative kernels could not
    run within CUDA graph conditional nodes.
    """
    try:
        check_cuda_python_cuda_graphs_conditional_nodes_supported()
    except:
        return False
    else:
        from cuda import cuda

        error, driver_version = cuda.cuDriverGetVersion()
        if error != cuda.CUresult.CUDA_SUCCESS:
            raise ImportError(f"cuDriverGetVersion() returned {cuda.cuGetErrorString(error)}")
        driver_version_major = driver_version // 1000
        driver_version_minor = (driver_version % 1000) // 10
        driver_version = (driver_version_major, driver_version_minor)
        return driver_version >= (12, 6)


@contextlib.contextmanager
def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditional_handle, device):
    """
```
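The helper above is presumably what the updated test consults before exercising cooperative kernels inside conditional node bodies. The CUDA driver reports its version as 1000 * major + 10 * minor (e.g. 12060 for 12.6), which is what the integer arithmetic decodes. As a minimal sketch of how such a check is typically consumed as a skip condition, assuming pytest and the nemo.core.utils.cuda_python_utils import path referenced later in the review thread (the test name and body are hypothetical):

```python
import pytest

# Assumed import path; the helper is added to cuda_python_utils.py in this diff.
from nemo.core.utils.cuda_python_utils import (
    cuda_python_conditional_node_cooperative_kernels_supported,
)


@pytest.mark.skipif(
    not cuda_python_conditional_node_cooperative_kernels_supported(),
    reason="Cooperative kernels inside CUDA graph conditional nodes require CUDA driver >= 12.6",
)
def test_full_graph_decoding_with_cooperative_kernels():
    # Hypothetical test body: exercise the full-graph (conditional node) decoding path.
    ...
```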
@@ -219,3 +241,19 @@ def run_nvrtc(kernel_string: str, kernel_name: bytes, program_name: bytes):

```python
    assert_drv(err)

    return kernel


@contextlib.contextmanager
def checked_graph(*args, **kwargs):
    """
    Wrapper around torch.cuda.graph that checks for common errors that are too
    vague for an end user to diagnose based on the error message.
    """
    try:
        with torch.cuda.graph(*args, **kwargs):
            yield
    except RuntimeError as err:
        if "CUDA error: invalid argument" in str(err):
            raise RuntimeError(
                "CUDA Graph capture failed. It is likely that you are calling a cooperative kernel in your RNN-T or TDT prediction network. Cooperative kernels are not allowed inside the bodies of CUDA Graph conditional nodes until CUDA 12.6. Please update to CUDA 12.6. File an issue if that still does not work."
            ) from err
        raise
```
Review comment: We only support CUDA 12.5 with the published PyTorch containers; see the support matrix: https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html. Make sure running without cuda_graphs decoding is supported by default, and in a later version we can make cuda_graphs on by default when the containers support it.

Review comment (referencing NeMo/nemo/core/utils/cuda_python_utils.py, line 21 in 8880c37): Hmm, I think @galv wants to preserve … I would suggest in this case applying fallback behavior like this (also, the same code in tdt_loop_labels_computer.py):

```python
if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH:
    try:
        self._full_graph_compile()
    except RuntimeError:
        # fallback to graphs without while loops
        self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS
        self._partial_graphs_compile()
```

Reply: Yeah, that's an interesting possibility. One of the big challenges is that the error returned by torch is not very precise: it is just a RuntimeError corresponding to "invalid argument" (cudaErrorInvalidValue), which is not precise enough for us to tell that the problem is specifically that the code is using a cooperative kernel within a conditional node's body graph. And unfortunately we cannot check whether this is the case, because the conditional node API does not currently expose a way to get the body graph(s) of a conditional node. Anyway, I suppose that if the error was not caused by a cooperative kernel but by something else, there is a good chance the error will also be thrown by the partial-graphs implementation. But that is still not a guarantee!

Reply: IMO it sounds like it's worth a shot, to be able to move forward here.
Review comment: Do we need checked_graph for partial graphs (without conditional nodes)?

Reply: This is a very good point. No, we don't. But I am considering doing your suggestion anyway for a fallback, in which case we wouldn't use checked_graph.
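If that fallback route is taken, a rough sketch of what the compile path might look like, as a variation on the snippet suggested above. The surrounding class (with cuda_graphs_mode, CudaGraphsMode, _full_graph_compile, and _partial_graphs_compile) is assumed from that suggestion; the method name and the warning message are illustrative, not part of the PR:

```python
import logging


def _compile_decoding_graphs(self):
    # Sketch only: attempt the full graph (with conditional nodes) first, then
    # fall back to partial graphs without while loops if capture raises, e.g.
    # when the driver rejects cooperative kernels inside conditional node
    # bodies (drivers older than CUDA 12.6).
    if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH:
        try:
            self._full_graph_compile()
            return
        except RuntimeError as err:
            logging.warning(
                "Full CUDA graph capture failed (%s); falling back to graphs without while loops.", err
            )
            self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS
    self._partial_graphs_compile()
```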