From c8c6aca364ddfd834187e1aaed25b8e6437282eb Mon Sep 17 00:00:00 2001
From: laurent
Date: Fri, 7 Jul 2023 07:55:09 +0100
Subject: [PATCH] Implement Gemm for CudaBlas.

---
// Reviewer notes (kept in the commentary region between `---` and the diff):
// - The first hunk adds one more `Gemm` impl for `CudaBlas`; every operand in it is tagged
//   CUDA_R_16BF and the test data is built with `half::bf16::from_f32`, so this is the bf16 variant.
// - The second hunk appends a bf16 GEMM check to the block that precedes the next `#[test]`
//   inside `mod tests`; it reuses `dev`, `blas`, `c`, `M`, `N`, `K` defined outside this hunk.
// - NOTE(review): generic argument lists look stripped from this copy of the patch
//   (e.g. `htod_sync_copy::(` / `alloc_zeros::(`), presumably `<half::bf16>` - confirm against
//   the original commit before applying.
 src/cublas/safe.rs | 124 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 124 insertions(+)

diff --git a/src/cublas/safe.rs b/src/cublas/safe.rs
index 5c6cbc76..d4243984 100644
--- a/src/cublas/safe.rs
+++ b/src/cublas/safe.rs
@@ -265,6 +265,84 @@ impl Gemm for CudaBlas {
     }
 }
 
+#[cfg(feature = "f16")] // compiled only when the "f16" feature is enabled (same gate is reused for bf16)
+impl Gemm for CudaBlas { // bf16 variant: a, b and c are all passed as CUDA_R_16BF below
+    unsafe fn gemm< // single matrix multiply, forwarded to result::gemm_ex
+        A: DevicePtr,
+        B: DevicePtr,
+        C: DevicePtrMut,
+    >(
+        &self,
+        cfg: GemmConfig,
+        a: &A,
+        b: &B,
+        c: &mut C,
+    ) -> Result<(), CublasError> {
+        let alpha: f32 = cfg.alpha.to_f32(); // scalars are widened to f32; the compute type below is CUBLAS_COMPUTE_32F
+        let beta: f32 = cfg.beta.to_f32();
+        result::gemm_ex(
+            self.handle,
+            cfg.transa,
+            cfg.transb,
+            cfg.m,
+            cfg.n,
+            cfg.k,
+            (&alpha) as *const f32 as *const _, // address of the local f32 `alpha`, type-erased for the call
+            *a.device_ptr() as *const _,
+            sys::cudaDataType_t::CUDA_R_16BF, // element type tag passed right after `a`
+            cfg.lda,
+            *b.device_ptr() as *const _,
+            sys::cudaDataType_t::CUDA_R_16BF, // element type tag passed right after `b`
+            cfg.ldb,
+            (&beta) as *const f32 as *const _, // address of the local f32 `beta`, type-erased for the call
+            *c.device_ptr_mut() as *mut _, // output buffer, taken mutably
+            sys::cudaDataType_t::CUDA_R_16BF, // element type tag passed right after `c`
+            cfg.ldc,
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_32F, // accumulate in f32 even though storage is bf16
+            sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
+        )
+    }
+    unsafe fn gemm_strided_batched< // batched form: same arguments plus per-operand strides and a batch count
+        A: DevicePtr,
+        B: DevicePtr,
+        C: DevicePtrMut,
+    >(
+        &self,
+        cfg: StridedBatchedConfig,
+        a: &A,
+        b: &B,
+        c: &mut C,
+    ) -> Result<(), CublasError> {
+        let alpha: f32 = cfg.gemm.alpha.to_f32(); // as in `gemm`: scalars widened to f32 to match CUBLAS_COMPUTE_32F
+        let beta: f32 = cfg.gemm.beta.to_f32();
+        result::gemm_strided_batched_ex(
+            self.handle,
+            cfg.gemm.transa,
+            cfg.gemm.transb,
+            cfg.gemm.m,
+            cfg.gemm.n,
+            cfg.gemm.k,
+            (&alpha) as *const f32 as *const _,
+            *a.device_ptr() as *const _,
+            sys::cudaDataType_t::CUDA_R_16BF,
+            cfg.gemm.lda,
+            cfg.stride_a, // presumably the element offset between consecutive `a` matrices - confirm in StridedBatchedConfig
+            *b.device_ptr() as *const _,
+            sys::cudaDataType_t::CUDA_R_16BF,
+            cfg.gemm.ldb,
+            cfg.stride_b,
+            (&beta) as *const f32 as *const _,
+            *c.device_ptr_mut() as *mut _,
+            sys::cudaDataType_t::CUDA_R_16BF,
+            cfg.gemm.ldc,
+            cfg.stride_c,
+            cfg.batch_size,
+            sys::cublasComputeType_t::CUBLAS_COMPUTE_32F,
+            sys::cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT,
+        )
+    }
+}
+
 impl Gemm for CudaBlas {
     unsafe fn gemm, B: DevicePtr, C: DevicePtrMut>(
         &self,
@@ -607,6 +685,52 @@ mod tests {
                 );
             }
         }
+
+        #[rustfmt::skip]
+        let a_dev = dev.htod_sync_copy::(&[ // 12 literals laid out as 3 rows of 4, converted to bf16 before upload
+            -0.5944882, 1.8055636, 0.52204555, -0.00397902,
+            -0.38346434, -0.38013917, 0.4198623, -0.22479166,
+            -1.6661372, -0.4568837, -0.9043474, 0.39125723,
+        ].map(half::bf16::from_f32)).unwrap();
+        #[rustfmt::skip]
+        let b_dev = dev.htod_sync_copy::(&[ // 20 literals laid out as 4 rows of 5, converted to bf16 before upload
+            1.1292169, -0.13450263, 0.62789696, -0.5685516, 0.21946938,
+            1.0585804, -0.39789402, 0.90205914, 0.989318, -0.3443096,
+            1.3412506, 0.3059701, -0.9714474, -0.36113533, -1.6809629,
+            3.4746711, -1.0930681, 0.16502666, -0.59988785, 0.41375792,
+        ].map(half::bf16::from_f32)).unwrap();
+        let mut c_dev = dev.alloc_zeros::(M * N).unwrap(); // zeroed M * N output buffer
+        unsafe {
+            blas.gemm(
+                GemmConfig {
+                    transa: sys::cublasOperation_t::CUBLAS_OP_N,
+                    transb: sys::cublasOperation_t::CUBLAS_OP_N,
+                    m: N as i32, // m and n are deliberately given as (N, M); see the operand order below
+                    n: M as i32,
+                    k: K as i32,
+                    alpha: half::bf16::from_f32(1.0),
+                    lda: N as i32,
+                    ldb: K as i32,
+                    beta: half::bf16::from_f32(0.0), // beta = 0 with a zeroed c_dev
+                    ldc: N as i32,
+                },
+                &b_dev, // operands are passed as (b, a); presumably the usual row-major-on-column-major swap - confirm
+                &a_dev,
+                &mut c_dev,
+            )
+        }
+        .unwrap();
+        let c_host = dev.sync_reclaim(c_dev).unwrap(); // copy the result back to the host
+        for m in 0..M {
+            for n in 0..N {
+                let found = c_host[m * N + n]; // result is read back as a flat buffer with row stride N
+                let expected = c[m][n]; // `c` comes from outside this hunk; it is read as half::f16 below
+                assert!(
+                    (half::bf16::to_f32(found) - half::f16::to_f32(expected)) <= 1e-2, // NOTE(review): no abs(), so only found > expected + 1e-2 fails; a large negative error passes - confirm intent and tolerance
+                    "found={found:?}, expected={expected:?}"
+                );
+            }
+        }
     }
 
     #[test]