|
| 1 | +// SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +//! Module for integration with CUDA. |
| 5 | +//! |
| 6 | +//! This module will become a standalone crate, likely called `dynamo-cuda`; however, for the time being, it will |
| 7 | +//! live as a submodule of `dynamo-llm`. |
| 8 | +//! |
| 9 | +//! This implementation will include a set of traits for extracting raw `cudarc::driver::sys` objects. |
| 10 | +//! |
| 11 | +//! Dynamo will generally not be the primary compute driver within an application, but a secondary source |
| 12 | +//! of logic that may be used in conjunction with the primary compute driver, e.g. when used with vLLM, |
| 13 | +//! PyTorch owns the primary CUDA context. |
| 14 | +//! |
| 15 | +//! In order for Dynamo to avoid creating its own CUDA context, the following traits are provided so |
| 16 | +//! that we may tap into the lower-level CUDA contexts, streams, events, etc. from external sources and leverage |
| 17 | +//! them within Dynamo. |
| 18 | +
|
| 19 | +use cudarc::driver::{ |
| 20 | + sys::{cuCtxPopCurrent_v2, cuCtxPushCurrent_v2, cudaError_enum, CUcontext, CUstream}, |
| 21 | + CudaContext, CudaStream, |
| 22 | +}; |
| 23 | +use std::pin::Pin; |
| 24 | +use std::{marker::PhantomData, sync::Arc}; |
| 25 | + |
| 26 | +pub trait DynamoCudaContextProvider { |
| 27 | + /// # Safety |
| 28 | + /// |
| 29 | + /// This method is unsafe because it directly accesses the underlying CUDA context. |
| 30 | + /// The caller must ensure that the context is valid and that the CUDA context is active. |
| 31 | + unsafe fn cu_context(&self) -> cudarc::driver::sys::CUcontext; |
| 32 | + |
| 33 | + fn bind_to_thread(&self) -> Pin<Box<DynamoCudaContextGuard>> { |
| 34 | + unsafe { DynamoCudaContextGuard::new(self.cu_context()) } |
| 35 | + } |
| 36 | +} |
| 37 | + |
| 38 | +pub trait DynamoCudaStreamProvider { |
| 39 | + /// # Safety |
| 40 | + /// |
| 41 | + /// This method is unsafe because it directly accesses the underlying CUDA stream. |
| 42 | + /// The caller must ensure that the stream is valid and that the CUDA context is active. |
| 43 | + /// |
| 44 | + /// Similarly, any pointers/references to data for which the stream will be accessed must |
| 45 | + /// have proper lifetimes and scoping, which is not guaranteed by this trait. |
| 46 | + unsafe fn cu_stream(&self) -> cudarc::driver::sys::CUstream; |
| 47 | + |
| 48 | + fn context(&self) -> Arc<dyn DynamoCudaContextProvider>; |
| 49 | +} |
| 50 | + |
| 51 | +/// A CUDA context guard that ensures safe access to CUDA contexts. |
| 52 | +/// |
| 53 | +/// This guard: |
| 54 | +/// - Cannot be moved (uses PhantomPinned) |
| 55 | +/// - Cannot be cloned |
| 56 | +/// - Cannot pass across async boundaries (!Send + !Sync) |
| 57 | +/// - Provides safe access to the underlying CUDA context |
| 58 | +/// - Automatically manages context lifecycle |
| 59 | +pub struct DynamoCudaContextGuard { |
| 60 | + context: cudarc::driver::sys::CUcontext, |
| 61 | + // Prevent the guard from being moved |
| 62 | + _pin: std::marker::PhantomPinned, |
| 63 | + // Prevent Send + Sync to avoid crossing async boundaries |
| 64 | + _not_send_sync: PhantomData<*const ()>, |
| 65 | +} |
| 66 | + |
| 67 | +impl DynamoCudaContextGuard { |
| 68 | + /// Create a new context guard from a context provider. |
| 69 | + /// |
| 70 | + /// This is a safe constructor that pushes the context onto the CUDA context stack |
| 71 | + /// and ensures it will be properly popped when the guard is dropped. |
| 72 | + /// |
| 73 | + /// # Arguments |
| 74 | + /// * `provider` - A reference to something that can provide a CUDA context |
| 75 | + /// |
| 76 | + /// # Returns |
| 77 | + /// A pinned context guard that manages the CUDA context safely |
| 78 | + /// |
| 79 | + /// # Panics |
| 80 | + /// Panics if the CUDA context push operation fails |
| 81 | + /// # Safety |
| 82 | + /// |
| 83 | + /// This function dereferences a raw pointer and interacts with the CUDA driver API. |
| 84 | + /// The caller must ensure the context is valid. |
| 85 | + pub unsafe fn new(context: CUcontext) -> Pin<Box<Self>> { |
| 86 | + // Push the context onto the CUDA context stack |
| 87 | + let result = cuCtxPushCurrent_v2(context); |
| 88 | + if result != cudaError_enum::CUDA_SUCCESS { |
| 89 | + panic!("Failed to push CUDA context: {:?}", result); |
| 90 | + } |
| 91 | + |
| 92 | + let guard = Self { |
| 93 | + context, |
| 94 | + _pin: std::marker::PhantomPinned, |
| 95 | + _not_send_sync: PhantomData, |
| 96 | + }; |
| 97 | + |
| 98 | + Box::pin(guard) |
| 99 | + } |
| 100 | + |
| 101 | + /// Get the raw CUDA context. |
| 102 | + /// |
| 103 | + /// This method is safe because the guard ensures the context remains valid |
| 104 | + /// for its lifetime and cannot be moved or passed across async boundaries. |
| 105 | + /// |
| 106 | + /// # Returns |
| 107 | + /// The raw CUDA context handle |
| 108 | + pub fn context(&self) -> cudarc::driver::sys::CUcontext { |
| 109 | + self.context |
| 110 | + } |
| 111 | +} |
| 112 | + |
| 113 | +impl Drop for DynamoCudaContextGuard { |
| 114 | + fn drop(&mut self) { |
| 115 | + // Pop the context from the CUDA context stack when the guard is dropped |
| 116 | + let mut popped_context: CUcontext = std::ptr::null_mut(); |
| 117 | + let result = unsafe { cuCtxPopCurrent_v2(&mut popped_context) }; |
| 118 | + |
| 119 | + // Log errors but don't panic in Drop |
| 120 | + if result != cudaError_enum::CUDA_SUCCESS { |
| 121 | + eprintln!("Warning: Failed to pop CUDA context in drop: {:?}", result); |
| 122 | + } |
| 123 | + |
| 124 | + // Verify we popped the expected context |
| 125 | + if popped_context != self.context { |
| 126 | + eprintln!( |
| 127 | + "Warning: Popped context {:?} does not match expected context {:?}", |
| 128 | + popped_context, self.context |
| 129 | + ); |
| 130 | + } |
| 131 | + } |
| 132 | +} |
| 133 | + |
| 134 | +/// A CUDA context provider that wraps an external CUDA context. |
| 135 | +pub struct ExternalCudaContext { |
| 136 | + // SAFETY: CUcontext is thread-safe to pass between threads and can be used concurrently. |
| 137 | + context: CUcontext, |
| 138 | +} |
| 139 | + |
| 140 | +// SAFETY: See notes on CUcontext above. |
| 141 | +unsafe impl Send for ExternalCudaContext {} |
| 142 | +unsafe impl Sync for ExternalCudaContext {} |
| 143 | + |
| 144 | +impl ExternalCudaContext { |
| 145 | + pub fn new(context: CUcontext) -> Arc<Self> { |
| 146 | + Arc::new(Self { context }) |
| 147 | + } |
| 148 | + |
| 149 | + pub fn cu_context(&self) -> CUcontext { |
| 150 | + self.context |
| 151 | + } |
| 152 | +} |
| 153 | + |
| 154 | +impl DynamoCudaContextProvider for ExternalCudaContext { |
| 155 | + unsafe fn cu_context(&self) -> cudarc::driver::sys::CUcontext { |
| 156 | + self.cu_context() |
| 157 | + } |
| 158 | +} |
| 159 | + |
| 160 | +/// A CUDA stream provider that wraps an external CUDA stream. |
| 161 | +pub struct ExternalCudaStream { |
| 162 | + stream: CUstream, |
| 163 | + context: Arc<dyn DynamoCudaContextProvider>, |
| 164 | +} |
| 165 | + |
| 166 | +impl ExternalCudaStream { |
| 167 | + pub fn new(stream: CUstream, context: Arc<dyn DynamoCudaContextProvider>) -> Self { |
| 168 | + Self { stream, context } |
| 169 | + } |
| 170 | +} |
| 171 | + |
| 172 | +impl DynamoCudaStreamProvider for ExternalCudaStream { |
| 173 | + unsafe fn cu_stream(&self) -> cudarc::driver::sys::CUstream { |
| 174 | + self.stream |
| 175 | + } |
| 176 | + |
| 177 | + fn context(&self) -> Arc<dyn DynamoCudaContextProvider> { |
| 178 | + self.context.clone() |
| 179 | + } |
| 180 | +} |
| 181 | + |
| 182 | +// Note: the `PhantomData<*const ()>` field makes `DynamoCudaContextGuard` !Send and !Sync, |
| 183 | +// which prevents the guard from crossing async boundaries. |
| 184 | + |
| 185 | +// Implementations of the provider traits for types from the [`cudarc`] crate. |
| 186 | + |
| 187 | +impl DynamoCudaContextProvider for CudaContext { |
| 188 | + unsafe fn cu_context(&self) -> cudarc::driver::sys::CUcontext { |
| 189 | + self.cu_ctx() |
| 190 | + } |
| 191 | +} |
| 192 | + |
| 193 | +impl DynamoCudaContextProvider for CudaStream { |
| 194 | + unsafe fn cu_context(&self) -> cudarc::driver::sys::CUcontext { |
| 195 | + self.context().cu_context() |
| 196 | + } |
| 197 | +} |
| 198 | + |
| 199 | +impl DynamoCudaStreamProvider for CudaStream { |
| 200 | + unsafe fn cu_stream(&self) -> cudarc::driver::sys::CUstream { |
| 201 | + self.cu_stream() |
| 202 | + } |
| 203 | + |
| 204 | + fn context(&self) -> Arc<dyn DynamoCudaContextProvider> { |
| 205 | + self.context().clone() |
| 206 | + } |
| 207 | +} |
0 commit comments