From b7d31bdf90c19e19147527c6b956f7896c2aaddf Mon Sep 17 00:00:00 2001 From: scxfjiang Date: Thu, 21 Nov 2024 15:47:47 +0000 Subject: [PATCH] change the location of gemm runner for Batched GEMM --- .../compiler/xla/stream_executor/stream.cc | 72 +++++++++---------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/tensorflow/compiler/xla/stream_executor/stream.cc b/tensorflow/compiler/xla/stream_executor/stream.cc index 2080bff5b004dc..097f6203c79b89 100644 --- a/tensorflow/compiler/xla/stream_executor/stream.cc +++ b/tensorflow/compiler/xla/stream_executor/stream.cc @@ -1703,12 +1703,6 @@ Stream &Stream::ThenBlasGemmBatched( uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, blas::CallContext context) { - if (gpu::GpuBlasLtEnabled()) { - auto &r = gpu::BlasLtGemmRunner::i(this); - CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc, batch_count, /* allocator */nullptr)); - return *this; - } return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, /*scratch_allocator=*/nullptr, context); @@ -1724,7 +1718,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - + if (gpu::GpuBlasLtEnabled()) { + auto &r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, batch_count, scratch_allocator)); + return *this; + } ThenBlasImpl, int, DeviceMemorySlice, int, float, @@ -1744,7 +1743,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - + if (gpu::GpuBlasLtEnabled()) { + auto &r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, batch_count, scratch_allocator)); + return *this; + } ThenBlasImpl, int, DeviceMemorySlice, int, float, @@ -1762,12 +1766,6 @@ Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, blas::CallContext context) { - if (gpu::GpuBlasLtEnabled()) { - auto &r = gpu::BlasLtGemmRunner::i(this); - CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc, batch_count, /* allocator */nullptr)); - return *this; - } return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, /*scratch_allocator=*/nullptr, context); @@ -1782,7 +1780,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - + if (gpu::GpuBlasLtEnabled()) { + auto &r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, batch_count, scratch_allocator)); + return *this; + } ThenBlasImpl, int, DeviceMemorySlice, int, float, DeviceMemorySlice, int, int, @@ -1800,12 +1803,6 @@ Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa, DeviceMemorySlice b, int ldb, double beta, DeviceMemorySlice c, int ldc, int batch_count, blas::CallContext context) { - if (gpu::GpuBlasLtEnabled()) { - auto &r = gpu::BlasLtGemmRunner::i(this); - CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc, batch_count, /* allocator */nullptr)); - return *this; - } return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, /*scratch_allocator=*/nullptr, context); @@ -1820,7 +1817,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - + if (gpu::GpuBlasLtEnabled()) { + auto &r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, batch_count, scratch_allocator)); + return *this; + } ThenBlasImpl, int, DeviceMemorySlice, int, double, @@ -1837,12 +1839,6 @@ Stream &Stream::ThenBlasGemmBatched( DeviceMemorySlice> a, int lda, DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, int ldc, int batch_count, blas::CallContext context) { - if (gpu::GpuBlasLtEnabled()) { - auto &r = gpu::BlasLtGemmRunner::i(this); - CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc, batch_count, /* allocator */nullptr)); - return *this; - } return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, /*scratch_allocator=*/nullptr, context); @@ -1858,7 +1854,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - + if (gpu::GpuBlasLtEnabled()) { + auto &r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, batch_count, scratch_allocator)); + return *this; + } ThenBlasImpl, DeviceMemorySlice>, int, DeviceMemorySlice>, int, std::complex, @@ -1877,12 +1878,6 @@ Stream &Stream::ThenBlasGemmBatched( DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, int ldc, int batch_count, blas::CallContext context) { - if (gpu::GpuBlasLtEnabled()) { - auto &r = gpu::BlasLtGemmRunner::i(this); - CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, - ldb, beta, c, ldc, batch_count, /* allocator */nullptr)); - return *this; - } return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, /*scratch_allocator=*/nullptr, context); @@ -1899,7 +1894,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); - + if (gpu::GpuBlasLtEnabled()) { + auto &r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, batch_count, scratch_allocator)); + return *this; + } ThenBlasImpl, DeviceMemorySlice>, int, DeviceMemorySlice>, int,