From ac54e3e7aa427ff6737d9911f1962a9dbaebcd4a Mon Sep 17 00:00:00 2001 From: 123epsilon Date: Sun, 26 Jan 2025 05:30:36 +0000 Subject: [PATCH] llama example working, bmm triton kernel --- .github/workflows/build_zoom_backend.yml | 124 ++++++++++++ aten/src/ATen/native/native_functions.yaml | 36 ++-- aten/src/ATen/native/zoom/Bmm.cpp | 122 ------------ .../native/zoom/DistributionRandomKernel.cu | 27 +++ .../ATen/native/zoom/DistributionUniform.cu | 15 ++ aten/src/ATen/native/zoom/HIPbmm.cu | 132 ------------- aten/src/ATen/native/zoom/TensorCompare.cu | 133 +++++++++++++ test/test_ops.py | 3 +- torch/zoom/__init__.py | 2 +- torch/zoom/zoom_triton_mm.py | 182 ++++++++++++++++++ 10 files changed, 501 insertions(+), 275 deletions(-) create mode 100644 .github/workflows/build_zoom_backend.yml delete mode 100644 aten/src/ATen/native/zoom/Bmm.cpp create mode 100644 aten/src/ATen/native/zoom/DistributionRandomKernel.cu create mode 100644 aten/src/ATen/native/zoom/DistributionUniform.cu delete mode 100644 aten/src/ATen/native/zoom/HIPbmm.cu create mode 100644 aten/src/ATen/native/zoom/TensorCompare.cu create mode 100644 torch/zoom/zoom_triton_mm.py diff --git a/.github/workflows/build_zoom_backend.yml b/.github/workflows/build_zoom_backend.yml new file mode 100644 index 00000000000000..aa7053cafe8379 --- /dev/null +++ b/.github/workflows/build_zoom_backend.yml @@ -0,0 +1,124 @@ +name: "Build PyTorch" + +on: + workflow_dispatch: + inputs: + force_debug_with_tmate: + type: boolean + description: 'Run the build with tmate session' + required: false + default: false + debug_with_tmate: + type: boolean + description: 'Run the build with a tmate session ONLY in case of failure' + required: false + default: false + pull_request: + push: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + build: + + strategy: + fail-fast: false + matrix: + include: + - name: "ubuntu-22.04" + runs-on: "mi300" + # container: "rocm/pytorch:rocm6.2.3_ubuntu22.04_py3.10_pytorch_release_2.3.0" + # runs-on: "nod-ai-shared-cpubuilder-manylinux-x86_64" + + runs-on: ${{ matrix.runs-on }} + + name: ${{ matrix.name }} + + env: + CACHE_DIR: ${{ github.workspace }}/.container-cache + # either the PR number or `branch-N` where N always increments + CACHE_KEY: linux-build-test-cpp-asserts-manylinux-v2-${{ format('{0}-{1}', github.ref_name, github.run_number) }} + + defaults: + run: + shell: bash + + permissions: + id-token: write + contents: write + + container: + image: ${{ matrix.container }} + + steps: + - name: "Check out repository" + uses: actions/checkout@v4.2.2 + with: + submodules: true + + - name: Enable cache + uses: actions/cache/restore@v3 + with: + path: ${{ env.CACHE_DIR }} + key: ${{ env.CACHE_KEY }} + restore-keys: linux-build-test-cpp- + + - name: "Build PyTorch" + id: build + run: | + + export CCACHE_DIR="${{ env.CACHE_DIR }}" + export CMAKE_C_COMPILER_LAUNCHER=ccache + export CMAKE_CXX_COMPILER_LAUNCHER=ccache + export CCACHE_SLOPPINESS=include_file_ctime,include_file_mtime,time_macros + + python -m venv venv + source venv/bin/activate + pip install -r requirements.txt + ./build.sh + + - name: "Audit" + id: audit + run: | + + sudo apt install patchelf + source venv/bin/activate + pip install auditwheel + auditwheel repair -w dist --plat manylinux_2_39_x86_64 dist/torch* + + - name: Save cache + uses: actions/cache/save@v3 + if: ${{ !cancelled() }} + with: + path: ${{ env.CACHE_DIR }} + key: ${{ env.CACHE_KEY }} + + - name: 
Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: ${{ matrix.name }}_artifact + path: dist + if-no-files-found: warn + + - name: Release current commit + uses: ncipollo/release-action@v1.12.0 + with: + artifacts: "dist/torch*.whl" + token: "${{ secrets.GITHUB_TOKEN }}" + tag: "latest" + name: "latest" + removeArtifacts: false + allowUpdates: true + replacesArtifacts: true + makeLatest: true + + - name: "Setup tmate session" + if: ${{ (failure() && inputs.debug_with_tmate) || inputs.force_debug_with_tmate }} + uses: mxschmitt/action-tmate@v3.18 + with: + limit-access-to-actor: true + install-dependencies: ${{ startsWith(matrix.runs-on, 'macos') || startsWith(matrix.runs-on, 'windows') }} \ No newline at end of file diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 1664a6642b4cc4..6271e79c453abf 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1352,7 +1352,6 @@ dispatch: CPU: bmm_out_cpu CUDA: bmm_out_cuda - PrivateUse1: bmm_out_zoom MPS: bmm_out_mps SparseCPU: bmm_out_sparse_cpu SparseCUDA: bmm_out_sparse_cuda @@ -1513,7 +1512,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_out + CPU, CUDA, PrivateUse1: clamp_out MPS: clamp_out_mps tags: pointwise @@ -1522,7 +1521,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_Tensor_out + CPU, CUDA, PrivateUse1: clamp_Tensor_out MPS: clamp_Tensor_out_mps tags: pointwise @@ -1553,7 +1552,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_max_out + CPU, CUDA, PrivateUse1: clamp_max_out MPS: clamp_max_out_mps tags: pointwise @@ -1562,7 +1561,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_max_Tensor_out + CPU, CUDA, PrivateUse1: clamp_max_Tensor_out MPS: clamp_max_Tensor_out_mps tags: pointwise @@ -1593,7 +1592,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_min_out + CPU, CUDA, PrivateUse1: clamp_min_out MPS: clamp_min_out_mps tags: pointwise @@ -1602,7 +1601,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: clamp_min_Tensor_out + CPU, CUDA, PrivateUse1: clamp_min_Tensor_out MPS: clamp_min_Tensor_out_mps tags: pointwise @@ -3168,7 +3167,7 @@ device_check: NoCheck device_guard: False dispatch: - CPU, CUDA, MPS: isnan + CPU, CUDA, MPS, PrivateUse1: isnan SparseCPU, SparseCUDA: isnan_sparse SparseCsrCPU, SparseCsrCUDA: isnan_sparse_csr autogen: isnan.out @@ -4121,7 +4120,6 @@ dispatch: CPU: mm_out_cpu CUDA: mm_out_cuda - PrivateUse1: mm_out_zoom MPS: mm_out_mps SparseCPU, SparseCUDA: _sparse_mm_out SparseCsrCPU, SparseCsrCUDA: _sparse_csr_mm_out @@ -6463,13 +6461,13 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA, MPS: where + CPU, CUDA, MPS, PrivateUse1: where tags: [core, pointwise] - func: where.self_out(Tensor condition, Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA, MPS: where_self_out + CPU, CUDA, MPS, PrivateUse1: where_self_out - func: where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor variants: function @@ -7874,7 +7872,7 @@ device_check: NoCheck device_guard: False dispatch: - CPU, CUDA, Meta, MPS: set_ + CPU, CUDA, Meta, MPS, PrivateUse1: set_ autogen: set.source_Storage, set.source_Storage_out tags: inplace_view @@ -7905,7 +7903,7 @@ device_check: NoCheck device_guard: False dispatch: - CPU, CUDA, Meta, MPS: set_tensor_ + CPU, CUDA, Meta, MPS, PrivateUse1: set_tensor_ autogen: set.source_Tensor, set.source_Tensor_out tags: inplace_view @@ -8663,7 +8661,7 @@ variants: method tags: nondeterministic_seeded dispatch: - CPU, CUDA: random_ + CPU, CUDA, PrivateUse1: random_ Meta: random_meta_ MPS: random_mps_ autogen: random.from, random.from_out @@ -8673,7 +8671,7 @@ tags: nondeterministic_seeded variants: method dispatch: - CPU, CUDA: random_ + CPU, CUDA, PrivateUse1: random_ Meta: random_meta_ MPS: random_mps_ autogen: random.to, random.to_out @@ -8683,7 +8681,7 @@ tags: nondeterministic_seeded variants: method dispatch: - CPU, CUDA: random_ + CPU, CUDA, PrivateUse1: random_ MPS: random_mps_ Meta: random_meta_ autogen: random, random.out @@ -8693,7 +8691,7 @@ tags: nondeterministic_seeded variants: method dispatch: - CPU, CUDA: uniform_ + CPU, CUDA, PrivateUse1: uniform_ MPS: uniform_mps_ Meta: uniform_meta_ autogen: uniform, uniform.out @@ -13077,7 +13075,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: isposinf_out + CPU, CUDA, PrivateUse1: isposinf_out SparseCPU, SparseCUDA: isposinf_sparse_out SparseCsrCPU, SparseCsrCUDA: isposinf_sparse_csr_out tags: pointwise @@ -13094,7 +13092,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: isneginf_out + CPU, CUDA, PrivateUse1: isneginf_out SparseCPU, SparseCUDA: isneginf_sparse_out SparseCsrCPU, SparseCsrCUDA: isneginf_sparse_csr_out tags: pointwise diff --git a/aten/src/ATen/native/zoom/Bmm.cpp b/aten/src/ATen/native/zoom/Bmm.cpp deleted file mode 100644 index 53e87a7eb3913e..00000000000000 --- a/aten/src/ATen/native/zoom/Bmm.cpp +++ /dev/null @@ -1,122 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include - -#ifndef AT_PER_OPERATOR_HEADERS -#include -#include -#else -#include -#include -#endif - - -namespace at::native { - // Forward decl, defined in HIPbmm.cu - template - void batched_matmul(const T* A, const T* B, T* C, int M, int N, int K, int batch_size); - - const Tensor& bmm_out_hip_impl(const Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2) { - // handle pathological cases - if (result.numel() == 0) { - return result; - } else if (batch1.size(2) == 0) { - return result.zero_(); - } - TORCH_CHECK(batch1.sizes()[2] == batch2.sizes()[1], "batch1 dim 2 must match batch2 dim 1"); - - c10::MaybeOwned result_ = c10::MaybeOwned::borrowed(result); - IntArrayRef result_strides = result.strides(); - IntArrayRef result_sizes = result.sizes(); - - int m = batch1.sizes()[1]; - int n = batch1.sizes()[2]; - int k = batch2.sizes()[2]; - int num_batches = result_->sizes()[0]; - - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "bmm_hip", [&] { - const scalar_t* batch1_ptr = batch1.const_data_ptr(); - const scalar_t* batch2_ptr = batch2.const_data_ptr(); - scalar_t* result_ptr = result_->mutable_data_ptr(); - - batched_matmul(batch1_ptr, 
batch2_ptr, result_ptr, m, n, k, num_batches); - }); - if (!result.is_same(*result_)) { - result.copy_(*result_); - } - return result; - - } - - TORCH_IMPL_FUNC(bmm_out_zoom)(const Tensor& batch1, const Tensor& batch2, const Tensor &result) - { - NoNamesGuard guard; - bmm_out_hip_impl(result, result, batch1, batch2); - } - - Tensor& mm_out_hip_impl(Tensor& result, const Tensor& mat1, const Tensor& mat2) { - // Make sure to keep addmm_hip below in sync with this code; it - // preflights a check to try to avoid actually needing to call - // expand(). - TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK( - mat1.dtype() == mat2.dtype(), - "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() - ) - - TensorArg targs[]{{result, "out", 0}, {mat1, "mat1", 1}, {mat2, "mat2", 2}}; - checkAllSameGPU(__func__, targs); - - IntArrayRef mat1_sizes = mat1.sizes(); - IntArrayRef mat2_sizes = mat2.sizes(); - at::ScalarType scalar_type = mat1.scalar_type(); - TORCH_CHECK(result.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); - - // resize result tensor - at::native::resize_output(result, {mat1_sizes[0], mat2_sizes[1]}); - IntArrayRef result_sizes = result.sizes(); - if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) { - return result; - } - - if (mat1.numel() == 0) { - // By definition, values in self should be ignored. nans and infs - // should not propagate - return result.zero_(); - } - - int m = mat1_sizes[0]; - int n = mat1_sizes[1]; - int k = mat2_sizes[1]; - - // TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result.is_conj()); - - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - scalar_type, - "mm_zoom", - [&] { - const scalar_t* mat1_ptr = mat1.const_data_ptr(); - const scalar_t* mat2_ptr = mat2.const_data_ptr(); - scalar_t* result_ptr = result.mutable_data_ptr(); - batched_matmul(mat1_ptr, mat2_ptr, result_ptr, m, n, k, 1); - }); - - return result; - } - - TORCH_IMPL_FUNC(mm_out_zoom)(const Tensor& self, const Tensor& mat2, const Tensor& result) - { - mm_out_hip_impl(const_cast(result), self, mat2); - } - -} // at::native - - diff --git a/aten/src/ATen/native/zoom/DistributionRandomKernel.cu b/aten/src/ATen/native/zoom/DistributionRandomKernel.cu new file mode 100644 index 00000000000000..7e8aa20d652bae --- /dev/null +++ b/aten/src/ATen/native/zoom/DistributionRandomKernel.cu @@ -0,0 +1,27 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include + +namespace at::native { + +void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, std::optional gen_) { + auto gen = get_generator_or_default(gen_, zoom::detail::getDefaultZoomGenerator()); + at::native::templates::zoom::random_from_to_kernel(iter, range, base, gen); +} + +void random_full_64_bits_range_kernel(TensorIteratorBase& iter, std::optional gen_) { + auto gen = get_generator_or_default(gen_, zoom::detail::getDefaultZoomGenerator()); + at::native::templates::zoom::random_full_64_bits_range_kernel(iter, gen); +} + +void random_kernel(TensorIteratorBase& iter, std::optional gen_) { + auto gen = get_generator_or_default(gen_, zoom::detail::getDefaultZoomGenerator()); + at::native::templates::zoom::random_kernel(iter, gen); +} + +REGISTER_PRIVATEUSE1_DISPATCH(random_from_to_stub, &random_from_to_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(random_stub, &random_kernel); +REGISTER_PRIVATEUSE1_DISPATCH(random_full_64_bits_range_stub, 
&random_full_64_bits_range_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/zoom/DistributionUniform.cu b/aten/src/ATen/native/zoom/DistributionUniform.cu new file mode 100644 index 00000000000000..25ed5e7b8b1148 --- /dev/null +++ b/aten/src/ATen/native/zoom/DistributionUniform.cu @@ -0,0 +1,15 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include + +namespace at::native { + +void uniform_kernel(TensorIteratorBase& iter, double from, double to, std::optional gen) { + auto generator = get_generator_or_default(gen, zoom::detail::getDefaultZoomGenerator()); + templates::zoom::uniform_kernel(iter, from, to, generator); +} + +REGISTER_PRIVATEUSE1_DISPATCH(uniform_stub, &uniform_kernel); + +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/HIPbmm.cu b/aten/src/ATen/native/zoom/HIPbmm.cu deleted file mode 100644 index a77a31efaf1af6..00000000000000 --- a/aten/src/ATen/native/zoom/HIPbmm.cu +++ /dev/null @@ -1,132 +0,0 @@ -#include -#include -#include -#include -#include - -namespace at::native { - - int num_threads() { - return 32; - } - - // Helper function to convert hip_bfloat16 to float - __device__ float bfloat16_to_float(hip_bfloat16 a) { - union { - uint32_t int32; - float float32; - } u = {uint32_t(a.data) << 16}; - return u.float32; - } - - // Helper function to convert float to hip_bfloat16 - __device__ hip_bfloat16 float_to_bfloat16(float a) { - union { - float float32; - uint32_t int32; - } u = {a}; - hip_bfloat16 b; - b.data = uint16_t(u.int32 >> 16); - return b; - } - - template - __device__ float convert_to_float(T a) { - return a; - } - - template <> - __device__ float convert_to_float(hip_bfloat16 a) { - return bfloat16_to_float(a); - } - - template <> - __device__ float convert_to_float<__half>( __half a) { - return __half2float(a); - } - - template - __device__ T convert_from_float(float a) { - return static_cast(a); - } - - template <> - __device__ hip_bfloat16 convert_from_float(float a) { - return float_to_bfloat16(a); - } - - template <> - __device__ __half convert_from_float<__half>(float a) { - return __float2half(a); - } - - - template - __global__ void batched_matmul_kernel(const T* A, const T* B, T* C, - int M, int N, int K, int batch_size) { - int row = blockIdx.y * blockDim.y + threadIdx.y; - int col = blockIdx.x * blockDim.x + threadIdx.x; - int batch = blockIdx.z; - - if (row < M && col < K && batch < batch_size) { - float sum = 0.0f; - for (int n = 0; n < N; ++n) { - sum += convert_to_float(A[batch * M * N + row * N + n]) * - convert_to_float(B[batch * N * K + n * K + col]); - } - C[batch * M * K + row * K + col] = convert_from_float(sum); - } - } - - template - void batched_matmul(const T* A, const T* B, T* C, - int M, int N, int K, int batch_size) { - dim3 threadsPerBlock(num_threads(), num_threads()); - dim3 numBlocks((K + threadsPerBlock.x - 1) / threadsPerBlock.x, - (M + threadsPerBlock.y - 1) / threadsPerBlock.y, - batch_size); - - hipLaunchKernelGGL(HIP_KERNEL_NAME(batched_matmul_kernel), numBlocks, threadsPerBlock, 0, 0, - A, B, C, M, N, K, batch_size); - C10_ZOOM_KERNEL_LAUNCH_CHECK(); - } - - // Specialization for at::Half - template <> - void batched_matmul(const at::Half* A, const at::Half* B, at::Half* C, - int M, int N, int K, int batch_size) { - dim3 threadsPerBlock(num_threads(), num_threads()); - dim3 numBlocks((K + threadsPerBlock.x - 1) / threadsPerBlock.x, - (M + threadsPerBlock.y - 1) / threadsPerBlock.y, - batch_size); - - 
hipLaunchKernelGGL(HIP_KERNEL_NAME(batched_matmul_kernel<__half>), numBlocks, threadsPerBlock, 0, 0, - reinterpret_cast(A), - reinterpret_cast(B), - reinterpret_cast<__half*>(C), - M, N, K, batch_size); - C10_ZOOM_KERNEL_LAUNCH_CHECK(); - } - - // Specialization for at::BFloat16 - template <> - void batched_matmul(const at::BFloat16* A, const at::BFloat16* B, at::BFloat16* C, - int M, int N, int K, int batch_size) { - dim3 threadsPerBlock(num_threads(), num_threads()); - dim3 numBlocks((K + threadsPerBlock.x - 1) / threadsPerBlock.x, - (M + threadsPerBlock.y - 1) / threadsPerBlock.y, - batch_size); - - hipLaunchKernelGGL(HIP_KERNEL_NAME(batched_matmul_kernel), numBlocks, threadsPerBlock, 0, 0, - reinterpret_cast(A), - reinterpret_cast(B), - reinterpret_cast(C), - M, N, K, batch_size); - C10_ZOOM_KERNEL_LAUNCH_CHECK(); - } - - // Explicit instantiations for supported types - template void batched_matmul(const float*, const float*, float*, int, int, int, int); - template void batched_matmul(const double*, const double*, double*, int, int, int, int); - -} // at::native \ No newline at end of file diff --git a/aten/src/ATen/native/zoom/TensorCompare.cu b/aten/src/ATen/native/zoom/TensorCompare.cu new file mode 100644 index 00000000000000..e92d058c9b7222 --- /dev/null +++ b/aten/src/ATen/native/zoom/TensorCompare.cu @@ -0,0 +1,133 @@ +#define TORCH_ASSERT_NO_OPERATORS +#include +#include +#include +#include +#include +#include + + +namespace at::native { + +namespace { + +void where_kernel_impl(TensorIterator &iter) { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_zoom", [&] { + gpu_kernel( + iter, + [=] GPU_LAMBDA (bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t { + return cond_val ? 
self_val : other_val; + }); + }); +} + +void isposinf_kernel_impl(TensorIteratorBase &iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isposinf_zoom", [&]() { + gpu_kernel( + iter, + [] GPU_LAMBDA (scalar_t a) -> bool { return a == std::numeric_limits::infinity(); } + ); + }); +} + +void isneginf_kernel_impl(TensorIteratorBase &iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.input_dtype(), "isneginf_zoom", [&]() { + gpu_kernel( + iter, + [] GPU_LAMBDA (scalar_t a) -> bool { return a == -std::numeric_limits::infinity(); } + ); + }); +} + +void clamp_kernel_impl(TensorIteratorBase& iter) { + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "clamp_zoom", [&] { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t v, scalar_t lower, scalar_t upper) -> scalar_t { + // Propagate nan, which doesn't propagate automatically for ROCm + if (at::_isnan(v)) { + return v; + } if (at::_isnan(lower)) { + return lower; + } if (at::_isnan(upper)) { + return upper; + } else { + return ::min(::max(v, lower), upper); + } + }); + }); +} + +void inline launch_clamp_scalar(TensorIteratorBase& iter, Scalar lim0, Scalar lim1, at::native::detail::ClampLimits minmax){ + AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.common_dtype(), "clamp_scalar_zoom", [&] { + using opmath_t = at::opmath_type; + auto lim0_val = lim0.to(); + auto lim1_val = lim1.to(); + + gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t { + // Propagate nan, which doesn't propagate automatically for ROCm + if (_isnan(static_cast(v))) { + return v; + } else if (minmax==at::native::detail::ClampLimits::Min){ + return ::max(static_cast(v), lim0_val); + } else if (minmax==at::native::detail::ClampLimits::Max){ + return ::min(static_cast(v), lim0_val); + } else { + return ::min(::max(static_cast(v), lim0_val), lim1_val); + } + }); + }); +} + + +void clamp_scalar_kernel_impl(TensorIteratorBase& iter, const Scalar& min, const Scalar& max) { + launch_clamp_scalar(iter, min, max, at::native::detail::ClampLimits::MinMax); +} + +void clamp_min_scalar_kernel_impl(TensorIteratorBase& iter, Scalar min) { + launch_clamp_scalar(iter, min, min, at::native::detail::ClampLimits::Min); +} + +void clamp_max_scalar_kernel_impl(TensorIteratorBase& iter, Scalar max) { + launch_clamp_scalar(iter, max, max, at::native::detail::ClampLimits::Max); +} + +} // anonymous namespace + + +REGISTER_PRIVATEUSE1_DISPATCH(where_kernel, &where_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(isposinf_stub, &isposinf_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(isneginf_stub, &isneginf_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(clamp_stub, &clamp_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(clamp_scalar_stub, &clamp_scalar_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(clamp_min_scalar_stub, &clamp_min_scalar_kernel_impl); +REGISTER_PRIVATEUSE1_DISPATCH(clamp_max_scalar_stub, &clamp_max_scalar_kernel_impl); + +template +__global__ void _assert_async_zoom_kernel(const scalar_t* input) { + ZOOM_KERNEL_ASSERT(input[0] != 0); +} + +__global__ void _assert_async_zoom_kernel(const c10::complex* input) { + ZOOM_KERNEL_ASSERT(input[0] != c10::complex(0, 0)); +} +__global__ void _assert_async_zoom_kernel(const c10::complex* input) { + ZOOM_KERNEL_ASSERT(input[0] != c10::complex(0, 0)); +} + +void _assert_async_zoom(const Tensor& self_tensor) { + const TensorBase &self = get_tensor_base(self_tensor); + auto n = self.numel(); + TORCH_CHECK(n != 0, "Boolean value of Tensor with no 
values is ambiguous"); + TORCH_CHECK(n < 2, "Boolean value of Tensor with more than one value is ambiguous"); + auto stream = c10::zoom::getCurrentZoomStream(); + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_assert_async_zoom", [&] { + _assert_async_zoom_kernel<<<1, 1, 0, stream>>>(self.const_data_ptr()); + C10_ZOOM_KERNEL_LAUNCH_CHECK(); + }); +} + +// TODO (tmanlaibaatar) Ignore assert msg for now +void _assert_async_msg_zoom(const Tensor& self_tensor, c10::string_view assert_msg) { + _assert_async_zoom(self_tensor); +} + +} // namespace at::native \ No newline at end of file diff --git a/test/test_ops.py b/test/test_ops.py index 44f503ae9b6ed8..cd473ac92c4f4f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -32,6 +32,7 @@ instantiate_device_type_tests, onlyCPU, onlyCUDA, + onlyCUDAAndZOOM, onlyNativeDeviceTypes, OpDTypes, ops, @@ -283,7 +284,7 @@ def test_numpy_ref(self, device, dtype, op): ) # Tests that the cpu and gpu results are consistent - @onlyCUDA + @onlyCUDAAndZOOM @suppress_warnings @slowTest @ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one) diff --git a/torch/zoom/__init__.py b/torch/zoom/__init__.py index 7b5a757d08520c..debc3c917f96ae 100644 --- a/torch/zoom/__init__.py +++ b/torch/zoom/__init__.py @@ -44,7 +44,7 @@ def _maybe_exchange_device(device: int) -> int: return -1 raise RuntimeError("PyTorch was compiled without Zoom support") - +from .zoom_triton_mm import * _initialized = False _tls = threading.local() diff --git a/torch/zoom/zoom_triton_mm.py b/torch/zoom/zoom_triton_mm.py new file mode 100644 index 00000000000000..6967ed7f8c1a77 --- /dev/null +++ b/torch/zoom/zoom_triton_mm.py @@ -0,0 +1,182 @@ +import torch +import triton +import triton.language as tl +from torch.library import register_kernel +torch.utils.rename_privateuse1_backend('zoom') + +@triton.heuristics({ + 'BLOCK_SIZE_M': lambda args: 128, + 'BLOCK_SIZE_N': lambda args: 64, + 'BLOCK_SIZE_K': lambda args: 32, + 'GROUP_SIZE_M': lambda args: 32, + 'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0, +}) +@triton.jit +def batched_matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + B, + M, + N, + K, + stride_ab, + stride_am, + stride_ak, + stride_bb, + stride_bk, + stride_bn, + stride_cb, + stride_cm, + stride_cn, + a_scale_ptr, + b_scale_ptr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + EVEN_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + APPLY_SCALE: tl.constexpr, + ACTIVATION: tl.constexpr, +): + """Kernel for computing the batched matmul C = A x B. 
+ A has shape (B, M, K), B has shape (B, K, N) and C has shape (B, M, N) + """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_batch = num_pid_m * num_pid_n + batch_id = pid // num_pid_in_batch + pid_in_batch = pid % num_pid_in_batch + + if GROUP_SIZE_M == 1: + pid_m = pid_in_batch // num_pid_n + pid_n = pid_in_batch % num_pid_n + else: + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid_in_batch // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid_in_batch % group_size_m) + pid_n = (pid_in_batch % num_pid_in_group) // group_size_m + + offs_k = tl.arange(0, BLOCK_SIZE_K) + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + a_ptrs = a_ptr + batch_id * stride_ab + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + batch_id * stride_bb + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + if APPLY_SCALE: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr) + + acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32 + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator += tl.dot(a, b) + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if APPLY_SCALE: + accumulator = accumulator * a_scale * b_scale + + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(c_ptr.type.element_ty) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + batch_id * stride_cb + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + +# Wrapper for batched gemm kernel +def batched_matmul(a, b, c, a_scale, b_scale, scale_a8_b8=False, activation=""): + assert a.shape[2] == b.shape[1], "Incompatible matrix dimensions!!!" + assert a.shape[0] == b.shape[0], "Incompatible batch dimensions!!!" + assert a.dtype == b.dtype, "Mixed dtype GEMMs are not supported!!!" + B, M, K = a.shape + _, K, N = b.shape + grid = lambda META: (B * triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + batched_matmul_kernel[grid]( + a, + b, + c, + B, + M, + N, + K, + a.stride(0), + a.stride(1), + a.stride(2), + b.stride(0), + b.stride(1), + b.stride(2), + c.stride(0), + c.stride(1), + c.stride(2), + a_scale, + b_scale, + APPLY_SCALE=scale_a8_b8, + ACTIVATION=activation, + ) + +# Activation function. 
+@triton.jit +def leaky_relu(x): + x = x + 1 + return tl.where(x >= 0, x, 0.01 * x) + +name_to_torch_types = { + 'int8': torch.int8, + 'int32': torch.int32, + 'fp16': torch.float16, + 'fp32': torch.float32, + 'bf16': torch.bfloat16, + 'fp8e5': torch.float8_e5m2fnuz, + 'fp8e4': torch.float8_e4m3fnuz, +} + +dtype_max = { + dtype: (torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)).max + for dtype in [ + torch.float8_e5m2fnuz, + torch.float8_e4m3fnuz, + torch.int8, + ] +} + +def mm_out_zoom(self, mat2, out): + batched_matmul(self.unsqueeze(0), mat2.unsqueeze(0), out.unsqueeze(0), None, None, False) + +def bmm_out_zoom(self, mat2, out): + batched_matmul(self, mat2, out, None, None, False) + +@register_kernel("aten::mm.out", "zoom") +def mm_out(self, mat2, out): + mm_out_zoom(self, mat2, out) + +@register_kernel("aten::mm", "zoom") +def mm(self, mat2): + out = self.new_empty((self.size(0), mat2.size(1))) + mm_out_zoom(self, mat2, out) + return out + +@register_kernel("aten::bmm.out", "zoom") +def bmm_out(self, mat2, out): + bmm_out_zoom(self, mat2, out) + +@register_kernel("aten::bmm", "zoom") +def bmm(self, mat2): + out = self.new_empty((self.size(0), self.size(1), mat2.size(2))) + bmm_out_zoom(self, mat2, out) + return out + \ No newline at end of file
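
For reference, a minimal smoke test (not part of the patch) sketching how the Triton-backed mm/bmm kernels registered in torch/zoom/zoom_triton_mm.py are expected to be exercised. It assumes a build of this branch where rename_privateuse1_backend('zoom') has taken effect, "zoom" tensors can be allocated, and importing torch.zoom pulls in the kernel registrations; the shapes, dtype, and tolerances below are illustrative only. torch.rand is used rather than torch.randn because this patch registers the uniform_/random_ kernels for the zoom backend but not a normal_ kernel.

# Hypothetical usage sketch, not part of the patch.
import torch
import torch.zoom  # assumption: this import registers the Triton mm/bmm kernels for the "zoom" device

def check_zoom_matmul(dtype=torch.float16, atol=1e-2, rtol=1e-2):
    # 2-D case: torch.mm should dispatch to the aten::mm registration above.
    a = torch.rand(128, 64, device="zoom", dtype=dtype)
    b = torch.rand(64, 256, device="zoom", dtype=dtype)
    ref = torch.mm(a.cpu().float(), b.cpu().float()).to(dtype)
    torch.testing.assert_close(torch.mm(a, b).cpu(), ref, atol=atol, rtol=rtol)

    # Batched case: torch.bmm should dispatch to the aten::bmm registration above.
    a3 = torch.rand(4, 128, 64, device="zoom", dtype=dtype)
    b3 = torch.rand(4, 64, 256, device="zoom", dtype=dtype)
    ref3 = torch.bmm(a3.cpu().float(), b3.cpu().float()).to(dtype)
    torch.testing.assert_close(torch.bmm(a3, b3).cpu(), ref3, atol=atol, rtol=rtol)

if __name__ == "__main__":
    check_zoom_matmul()
    print("zoom Triton mm/bmm smoke test passed")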