The rust-cuda library aims to provide an idiomatic Rust interface for CUDA programming, enabling Rust developers to write and execute CUDA kernels directly from Rust.
It was created specifically for writing custom kernels that optimize a Mixture of Experts (MoE) layer in large language models (LLMs).
The mocking framework mockall is used to define mock implementations of the CUDA and NVRTC bindings in the `mock.rs` files.
These are enabled with the `mock` feature flag, which lets the tests run without a GPU or CUDA drivers:
```shell
cargo test --features mock
```
Without the feature flag, the real implementations in `wrapper.rs` (for CUDA) and `compiler.rs` (for NVRTC) are used.
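As a rough illustration of how a mock can stand in for the driver, the sketch below defines a hypothetical `DeviceApi` trait, a hand-written `MockDevice`, and a `pick_device` function that depends only on the trait. None of these names come from the crate. mockall's `#[automock]` attribute generates mock types of this kind automatically; the hand-written version is used here only to keep the example dependency-free.

```rust
/// Hypothetical trait over device-query calls a CUDA wrapper might expose.
/// The crate's real names and signatures may differ; errors are raw
/// `CUresult` codes here to keep the sketch small.
pub trait DeviceApi {
    /// Number of visible devices (the driver call is `cuDeviceGetCount`).
    fn device_count(&self) -> Result<i32, i32>;
    /// Total global memory of device `ordinal` in bytes (`cuDeviceTotalMem`).
    fn total_mem(&self, ordinal: i32) -> Result<usize, i32>;
}

/// Hand-written mock: returns canned values instead of calling the driver,
/// so tests need no GPU. `mockall::automock` can generate a type like this.
pub struct MockDevice {
    /// Memory size reported for each fake device, indexed by ordinal.
    pub mems: Vec<usize>,
}

impl DeviceApi for MockDevice {
    fn device_count(&self) -> Result<i32, i32> {
        Ok(self.mems.len() as i32)
    }

    fn total_mem(&self, ordinal: i32) -> Result<usize, i32> {
        if ordinal < 0 {
            return Err(101); // 101 = CUDA_ERROR_INVALID_DEVICE
        }
        self.mems.get(ordinal as usize).copied().ok_or(101)
    }
}

/// Code written against the trait runs unchanged on the mock or on a real
/// backend. Picks the device with the most memory (first one on ties).
pub fn pick_device<A: DeviceApi>(api: &A) -> Result<i32, i32> {
    let mut best: Option<(i32, usize)> = None;
    for ordinal in 0..api.device_count()? {
        let mem = api.total_mem(ordinal)?;
        if best.map_or(true, |(_, m)| mem > m) {
            best = Some((ordinal, mem));
        }
    }
    best.map(|(ordinal, _)| ordinal).ok_or(100) // 100 = CUDA_ERROR_NO_DEVICE
}

// Inside a crate, a feature flag could then choose the backend, e.g.:
// #[cfg(feature = "mock")]      use crate::cuda::mock as backend;
// #[cfg(not(feature = "mock"))] use crate::cuda::wrapper as backend;
```

Because `pick_device` is generic over the trait, the same function can be exercised in tests with canned device data and used unchanged with a real backend.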
- Initialization:
  - start by selecting a CUDA device and allocating the necessary memory on the device using the `device` and `memory` modules
- Kernel Development:
  - instead of building CUDA kernels with a separate C/C++ toolchain, manage kernel compilation directly from Rust
  - the library compiles CUDA C++ kernel source at runtime through NVRTC (NVIDIA's runtime compilation library), producing PTX that can be loaded onto the device
- Kernel Execution:
  - with kernels compiled, launch them by passing inputs (data, model parameters) from Rust to the GPU, run the computation, and copy the results back into Rust
- Optimization Loop:
  - iteratively refine kernels and execution parameters (such as block and grid sizes) based on performance measurements, all within Rust's ecosystem
- Integration into Larger Models:
  - the kernels can be integrated into larger models, such as the MoE layer of an LLM
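The host-side parts of this workflow can be sketched without a GPU. The example below is an assumption-based illustration, not the crate's API: `SCALE_KERNEL_SRC` is a sample CUDA C++ kernel of the kind NVRTC compiles, `grid_size` computes the launch grid, `scale_host` is a CPU reference for checking results, and `best_block_size` is a simple timing loop for the optimization step. All of these names are hypothetical.

```rust
use std::time::{Duration, Instant};

/// Illustrative CUDA C++ kernel kept as a Rust string, ready to hand to
/// NVRTC. `extern "C"` turns off C++ name mangling, so after loading the
/// compiled module the function can be looked up by the plain name "scale".
pub const SCALE_KERNEL_SRC: &str = r#"
extern "C" __global__ void scale(float* data, float factor, unsigned int n) {
    unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        data[i] *= factor;
    }
}
"#;

/// Number of blocks so that `grid * block >= n`: ceiling division,
/// written without `n + block - 1` so it cannot overflow near `u32::MAX`.
pub fn grid_size(n: u32, block: u32) -> u32 {
    assert!(block > 0, "block size must be non-zero");
    n / block + u32::from(n % block != 0)
}

/// CPU reference for the same computation, useful for checking kernel
/// output (and for tests that run without a GPU).
pub fn scale_host(data: &mut [f32], factor: f32) {
    for x in data.iter_mut() {
        *x *= factor;
    }
}

/// Optimization-loop helper: times `launch` for each candidate block size
/// and returns the fastest. `launch` is a hypothetical closure that must
/// launch the kernel AND wait for it to finish (e.g. synchronize the
/// context or stream); otherwise host timing measures only launch overhead.
pub fn best_block_size<F: FnMut(u32)>(candidates: &[u32], runs: u32, mut launch: F) -> Option<u32> {
    let mut best: Option<(u32, Duration)> = None;
    for &block in candidates {
        launch(block); // warm-up: the first launch may include one-time costs
        let start = Instant::now();
        for _ in 0..runs {
            launch(block);
        }
        let elapsed = start.elapsed();
        if best.map_or(true, |(_, t)| elapsed < t) {
            best = Some((block, elapsed));
        }
    }
    best.map(|(block, _)| block)
}
```

In a hypothetical end-to-end flow, `SCALE_KERNEL_SRC` would be compiled with NVRTC, `scale` launched with `grid_size(n, block)` blocks of `block` threads, and the copied-back data compared against `scale_host`. For precise GPU timings, CUDA events are usually preferred over host wall-clock time.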
The library is organized into several modules, each responsible for different aspects of CUDA programming:
- `src/lib.rs`: The main entry point of the library, tying together the various modules and functionality
- `src/memory.rs`: Handles GPU memory allocation, deallocation, and data transfers
- `src/device.rs`: Manages CUDA device selection and queries device properties
- `src/kernel.rs`: Focuses on compiling, managing, and executing CUDA kernels
- `src/cuda/mod.rs`: Contains low-level bindings and safe wrappers for CUDA API calls
- `src/cuda/ffi.rs`: Raw FFI bindings to CUDA's C APIs
- `src/cuda/wrapper.rs`: Safe, idiomatic Rust wrappers around the CUDA FFI bindings
- `src/cuda/mock.rs`: Holds mock implementations
- `src/nvrtc/mod.rs`: Facilitates runtime compilation of CUDA kernels using NVRTC
- `src/nvrtc/ffi.rs`: Raw FFI bindings to NVRTC's C APIs
- `src/nvrtc/compiler.rs`: Higher-level functionality for compiling and managing CUDA code at runtime
- `src/nvrtc/mock.rs`: Holds mock implementations
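A split between raw FFI bindings and safe wrappers usually means translating C status codes into Rust types and tying resource lifetimes to ownership. The sketch below shows that pattern under stated assumptions: `CudaError`, `check`, and `DeviceBuffer` are hypothetical names, only three `CUresult` values are modeled, and the free function is passed in as a plain function pointer so the example runs without CUDA.

```rust
use std::fmt;

// A few CUDA driver `CUresult` values, as defined in `cuda.h`.
const CUDA_SUCCESS: i32 = 0;
const CUDA_ERROR_INVALID_VALUE: i32 = 1;
const CUDA_ERROR_OUT_OF_MEMORY: i32 = 2;

/// Typed error a wrapper layer could expose instead of raw status codes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CudaError {
    InvalidValue,
    OutOfMemory,
    /// Any code this sketch does not name explicitly.
    Other(i32),
}

impl fmt::Display for CudaError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            CudaError::InvalidValue => write!(f, "CUDA_ERROR_INVALID_VALUE"),
            CudaError::OutOfMemory => write!(f, "CUDA_ERROR_OUT_OF_MEMORY"),
            CudaError::Other(code) => write!(f, "CUresult {}", code),
        }
    }
}

impl std::error::Error for CudaError {}

/// Turn the status code returned by a raw FFI call into a `Result`,
/// so callers can use `?` instead of checking integers by hand.
pub fn check(code: i32) -> Result<(), CudaError> {
    match code {
        CUDA_SUCCESS => Ok(()),
        CUDA_ERROR_INVALID_VALUE => Err(CudaError::InvalidValue),
        CUDA_ERROR_OUT_OF_MEMORY => Err(CudaError::OutOfMemory),
        other => Err(CudaError::Other(other)),
    }
}

/// RAII owner of a device allocation: the memory is released in `Drop`,
/// so it is freed even on early returns. `free` stands in for the FFI
/// free call (in real code, an `unsafe` call to `cuMemFree`).
pub struct DeviceBuffer {
    ptr: u64, // `CUdeviceptr` is a 64-bit integer handle on 64-bit platforms
    len: usize,
    free: fn(u64) -> i32,
}

impl DeviceBuffer {
    /// Wraps an already-allocated device pointer of `len` bytes.
    pub fn new(ptr: u64, len: usize, free: fn(u64) -> i32) -> Self {
        DeviceBuffer { ptr, len, free }
    }

    pub fn len(&self) -> usize {
        self.len
    }

    pub fn as_raw(&self) -> u64 {
        self.ptr
    }
}

impl Drop for DeviceBuffer {
    fn drop(&mut self) {
        // `drop` cannot return an error, so a failed free is only reported.
        if let Err(e) = check((self.free)(self.ptr)) {
            eprintln!("failed to free device memory: {}", e);
        }
    }
}
```

Freeing in `Drop` means a buffer cannot leak through a forgotten free call, at the cost of having nowhere to return an error from a failed free.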
(add instructions on how to build, run tests, and basic usage examples)