From fb3e03cba5cde8cd06859ec5e16667b1453bc5a6 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 18 Nov 2022 20:40:30 +0530 Subject: [PATCH 01/48] add cutlass source files for fusedL2NN with initial set of changes --- .../raft/distance/detail/fused_l2_nn.cuh | 12 + .../detail/fused_l2_nn_cutlass_base.cuh | 179 ++++++ .../distance/detail/fused_l2_nn_epilogue.cuh | 102 +++ .../fused_l2_nn_epilogue_elementwise.cuh | 139 +++++ .../raft/distance/detail/fused_l2_nn_gemm.h | 239 +++++++ .../predicated_tile_iterator_reduced_vec.h | 581 ++++++++++++++++++ 6 files changed, 1252 insertions(+) create mode 100755 cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh create mode 100755 cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh create mode 100755 cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh create mode 100755 cpp/include/raft/distance/detail/fused_l2_nn_gemm.h create mode 100755 cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 1385d0aa09..e0ca99478b 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -258,6 +258,18 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, obj.run(); } +// final op functor for FusedL2NN used in its cutlass version +// to convert the distance value & key(loc id) into key-value pair +template +struct kvp_fin_op { + __host__ __device__ kvp_fin_op() noexcept {}; + // functor signature. + __host__ __device__ OutType operator()(AccType d_val, Index idx) const noexcept + { + return OutType(d_val, idx); + } +}; + template + +#include +#include +#include + +#include +#include +#include +#include + +#include "./pairwise_distance_epilogue_elementwise.h" +#include "./pairwise_distance_gemm.h" + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +namespace raft { +namespace distance { +namespace detail { + +template +void cutlassFusedL2NNKernel(const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + int* mutexes, + FinalLambda fin_op, + DistanceFn dist_op, + cudaStream_t stream) +{ + static_assert(!(std::is_same::value), + "OutType bool is not supported use uint8_t instead"); + + using EpilogueOutputOp = + cutlass::epilogue::thread::PairwiseDistanceEpilogueElementwise; + constexpr int batch_count = 1; + + constexpr auto mode = cutlass::gemm::GemmUniversalMode::kGemm; + + typename EpilogueOutputOp::Params epilog_op_param(dist_op, fin_op); + + const DataT *a, *b; + + IdxT gemm_lda, gemm_ldb; + + // Number of pipelines you want to use + constexpr int NumStages = 3; + // Alignment + constexpr int Alignment = VecLen; + + // default initialize problem size with row major inputs + auto problem_size = cutlass::gemm::GemmCoord(n, m, k); + + using cutlassDistKernel = + typename cutlass::gemm::kernel::PairwiseDistanceGemm::GemmKernel; + + using cutlassDist = cutlass::gemm::device::GemmUniversalAdapter; + + if constexpr (isRowMajor) { + a = y; + b = x; + gemm_lda = ldb; + gemm_ldb = lda; + } else { + problem_size = cutlass::gemm::GemmCoord(m, n, k); + a = x; + b = y; + gemm_lda = lda; + gemm_ldb = ldb; + } + + typename cutlassDist::Arguments arguments{ + mode, problem_size, batch_count, epilog_op_param, a, b, + xn, // C matrix eq vector param, which here is A norm + nullptr, // tensor_Z, + (DataT*)yn, // this is broadcast vec, which is required to be non-const param + dOutput, // Output distance matrix + (int64_t)0, // batch stride A + (int64_t)0, // batch stride B + (int64_t)0, // batch stride Norm A + (int64_t)0, + (int64_t)0, // batch stride Norm B + (int64_t)0, // batch stride Output + gemm_lda, // stride A + gemm_ldb, // stride B + 1, // stride A norm + 0, // this is no-op for Z + 0, // This must be zero + ldd // stride Output matrix + }; + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = cutlassDist::get_workspace_size(arguments); + // Allocate workspace memory + rmm::device_uvector workspace(workspace_size, stream); + // Instantiate CUTLASS kernel depending on templates + cutlassDist cutlassDist_op; + // Check the problem size is supported or not + cutlass::Status status = cutlassDist_op.can_implement(arguments); + CUTLASS_CHECK(status); + // Initialize CUTLASS kernel with arguments and workspace pointer + status = cutlassDist_op.initialize(arguments, workspace.data(), stream); + CUTLASS_CHECK(status); + // Launch initialized CUTLASS kernel + status = cutlassDist_op(); + CUTLASS_CHECK(status); +} + +}; // namespace detail +}; // namespace distance +}; // namespace raft +#endif // (__CUDACC_VER_MAJOR__ < 12) +#pragma GCC diagnostic pop diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh new file mode 100755 index 0000000000..11c8f5482b --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + +This is adapted from DefaultEpilogueWithBroadcastTensorOp from CUTLASS 2.9.0 +(https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h#L75) + +This epilogue allows us to load norm buffers using PredicatedTileIteratorNormVec +and EpilogueWithBroadcast used for distances L2/cosine as well as applies user-define elementwise +operation. +-- A norm load is provided PredicatedTileIteratorNormVec +-- B norm load is provided by EpilogueWithBroadcast +-- elementwise operation is provided by OutputOp +*/ + +#pragma once + +#include +#include +#include + +#include + +#include "./predicated_tile_iterator_normvec.h" +#include "./predicated_tile_iterator_reduced_vec.h" +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for TensorOps. +template +struct FusedL2NNEpilogue { + /// Use defaults related to the existing epilogue + using Base = + DefaultEpilogueTensorOp; + + // + // Stores the result z = (y = GEMM(A, B, C), broadcast) + // + using RowNormTileIterator = cutlass::epilogue::threadblock:: + PredicatedTileIteratorNormVec; + + // + // Additional tensor tile iterator - stores t = Elementwise(z) + // + using OutputTileIterator = + cutlass::epilogue::threadblock::PredicatedTileIteratorReducedVec; + + /// Define the epilogue + using Epilogue = EpilogueWithBroadcast; +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh new file mode 100755 index 0000000000..5e9d48ee7c --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { + +namespace detail { + + +// TODO: specialize this function for MinAndDistanceReduceOp +// with atomicCAS of 64 bit which will eliminate mutex and shfls +template +DI void updateReducedVal( + int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, IdxT m, IdxT gridStrideY) +{ + const auto lid = threadIdx.x % raft::WarpSize; + const auto accrowid = threadIdx.x / P::AccThCols; + + // Update each output row in order within a warp. This will resolve hang + // issues with pre-Volta architectures +#pragma unroll + for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { + if (lid == j * P::AccThCols) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rid = gridStrideY + accrowid + i * P::AccThRows; + if (rid < m) { + auto value = val[i]; + while (atomicCAS(mutex + rid, 0, 1) == 1) + ; + __threadfence(); + red_op(rid, min + rid, value); + __threadfence(); + atomicCAS(mutex + rid, 1, 0); + } + } + } + } +} + +template +void fusedL2NNImpl(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + int* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) +{ + // The kernel policy is determined by fusedL2NN. + typedef Policy P; + + dim3 blk(P::Nthreads); + auto nblks = raft::ceildiv(m, P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef KeyValuePair KVPair; + + // Accumulation operation lambda + auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; + + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + if (initOutBuffer) { + initKernel + <<>>(min, m, maxVal, redOp); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; + + constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + if (sqrt) { + auto fusedL2NNSqrt = fusedL2NNkernel; + dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NNSqrt); + + fusedL2NNSqrt<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, fin_op); + } else { + auto fusedL2NN = fusedL2NNkernel; + dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NN); + fusedL2NN<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, fin_op); + } + + RAFT_CUDA_TRY(cudaGetLastError()); +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h new file mode 100755 index 0000000000..1d380bfbbf --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include +#include + +#include "./fused_l2_nn_epilogue.cuh" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Element type for final output + // typename ElementOutT, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Number of stages used in the pipelined mainloop + int Stages, + /// data layout row/column major of inputs + bool isRowMajor> +struct FusedL2NNGemm { + // This struct is specialized for fp32/3xTF32 + + /// Threadblock-level tile size (concept: GemmShape) + using ThreadblockShape = + cutlass::gemm::GemmShape<128, 128, 16>; // <- threadblock tile M = 128, N = 128, K = 16 + /// Warp-level tile size (concept: GemmShape) + // This code section describes tile size a warp will compute + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = 64, N = 64, K = 16 + /// Warp-level tile size (concept: GemmShape) + // This code section describes the size of MMA op + using InstructionShape = + cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 + + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastF32; + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using OperatorClass = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using ArchTag = cutlass::arch::Sm80; + + // This code section describes how threadblocks are scheduled on GPU + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + /// data layout for final output matrix. + // we keep this same layout even for column major inputs + using LayoutOutput = cutlass::layout::RowMajor; + + typedef typename std::conditional::type NormXLayout; + + typedef typename std:: + conditional::type LayoutA_; + + typedef typename std:: + conditional::type LayoutB_; + + using GemmBase = typename DefaultGemmUniversal::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::FusedL2NNEpilogue< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementAccumulator, + typename EpilogueOutputOp::ElementT, + ElementAccumulator, + EpilogueOutputOp, + NormXLayout, + GemmBase::Epilogue::kElementsPerAccess>::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = GemmWithFusedEpilogue; +}; + +template < + /// Layout type for A matrix operand + int kAlignmentA, + /// Layout type for B matrix operand + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Number of stages used in the pipelined mainloop + int Stages, + /// data layout row/column major of inputs + bool isRowMajor> +struct FusedL2NNGemm { + // using Transform = cutlass::ComplexTransform::kNone; + // Threadblock-level tile size (concept: GemmShape) + using ThreadblockShape = + cutlass::gemm::GemmShape<64, 64, 16>; // <- threadblock tile M = 64, N = 64, K = 16 + /// Warp-level tile size (concept: GemmShape) + // This code section describes tile size a warp will compute + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; // <- warp tile M = 32, N = 32, K = 16 + /// Warp-level tile size (concept: GemmShape) + // This code section describes the size of MMA op + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + // Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAdd; + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using OperatorClass = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using ArchTag = cutlass::arch::Sm80; + + // This code section describes how threadblocks are scheduled on GPU + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + /// data layout for final output matrix. + // we keep this same layout even for column major inputs + using LayoutOutput = cutlass::layout::RowMajor; + + typedef typename std::conditional::type NormXLayout; + + typedef typename std:: + conditional::type LayoutA_; + + typedef typename std:: + conditional::type LayoutB_; + + using GemmBase = typename DefaultGemmUniversal::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::FusedL2NNEpilogue< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementC_, + typename EpilogueOutputOp::ElementT, + ElementC_, + EpilogueOutputOp, + NormXLayout, + GemmBase::Epilogue::kElementsPerAccess>::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = GemmWithFusedEpilogue; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass \ No newline at end of file diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h new file mode 100755 index 0000000000..e56b34c3b0 --- /dev/null +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h @@ -0,0 +1,581 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + +This file contains a customized version of PredicatedTileIterator from CUTLASS 2.9.0 +(https://github.com/NVIDIA/cutlass/blob/v2.9.0/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L75) + +Changes: +- added `Layout_` template param +- Only the row index is used to load the data in load_with_byte_offset(). + This way the same normalization data is used across all columns in a row. + +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +template +class PredicatedTileIteratorReducedVec { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = Layout_; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); + static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); + static_assert(ThreadMap::Iterations::kCluster > 0, "ThreadMap::Iterations::kCluster must be > 0"); + static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) + { + } + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { enable(); } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + + private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + + private: + // + // Methods + // + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorReducedVec(PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) + : params_(params), indices_(indices) + { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + mask_.predicates[c] = + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { mask_.clear(); } + + if (ScatterD && !indices) { mask_.clear(); } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride); + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[0], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { byte_pointer += params_.increment_row; } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { byte_pointer += params_.increment_row; } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) const { store_with_byte_offset(frag, 0); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void downsample_load_with_byte_offset(Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + + int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + + (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; + + int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { byte_pointer += params_.increment_row; } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void upsample_load_with_byte_offset(Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + int row_add_P = add_P; + int row_add_Q = add_Q; + if (output_P > convolution_P - 2) row_add_P = 0; + if (output_Q > convolution_Q - 2) row_add_Q = 0; + + int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + + ((output_P + row_add_P) / 2) * (convolution_Q / 2) + + (output_Q + row_add_Q) / 2; + + int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { byte_pointer += params_.increment_row; } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { return MatrixCoord(thread_start_row_, thread_start_column_); } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { return thread_start_row_; } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { return thread_start_column_; } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { return extent_row_; } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { return extent_column_; } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorReducedVec& operator++() + { + ++state_[0]; + + if (!ScatterD) { byte_pointer_ += params_.advance_row; } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { mask_.clear(); } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { mask_.enable(); } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// From 7039a50033a98c439a2d89d63a64b07f14c1a309 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 30 Nov 2022 22:03:43 +0530 Subject: [PATCH 02/48] temp commit to store current progress --- .../detail/fused_l2_nn_cutlass_base.cuh | 11 +- .../fused_l2_nn_epilogue_elementwise.cuh | 264 ++++++++++-------- .../raft/distance/detail/fused_l2_nn_gemm.h | 4 +- .../predicated_tile_iterator_reduced_vec.h | 9 +- 4 files changed, 166 insertions(+), 122 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh index 8f688ef316..31b44ab81f 100755 --- a/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh @@ -62,7 +62,8 @@ template + typename ReduceOpT, + typename KVPReduceOpT> void cutlassFusedL2NNKernel(const DataT* x, const DataT* y, const DataT* xn, @@ -77,6 +78,8 @@ void cutlassFusedL2NNKernel(const DataT* x, int* mutexes, FinalLambda fin_op, DistanceFn dist_op, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, cudaStream_t stream) { static_assert(!(std::is_same::value), @@ -109,14 +112,18 @@ void cutlassFusedL2NNKernel(const DataT* x, // default initialize problem size with row major inputs auto problem_size = cutlass::gemm::GemmCoord(n, m, k); + constexpr bool isRowMajor = true; + using cutlassDistKernel = - typename cutlass::gemm::kernel::PairwiseDistanceGemm::GemmKernel; diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh index 5e9d48ee7c..2c31bf38da 100755 --- a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,126 +14,158 @@ * limitations under the License. */ +// +/*! \file + \brief Functor performing distance operations used by epilogues of pairwise distance + * kernels. +* This is adapted from LinearCombinationBiasElementwise from CUTLASS 2.9.0 +* customized for applying elementwise distance formula on accumulated GEMM value +* and applying user-defined final custom operation on the distance value. +*/ + #pragma once -#include -#include -#include -#include -#include -#include - -namespace raft { -namespace distance { - -namespace detail { - - -// TODO: specialize this function for MinAndDistanceReduceOp -// with atomicCAS of 64 bit which will eliminate mutex and shfls -template -DI void updateReducedVal( - int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, IdxT m, IdxT gridStrideY) -{ - const auto lid = threadIdx.x % raft::WarpSize; - const auto accrowid = threadIdx.x / P::AccThCols; - - // Update each output row in order within a warp. This will resolve hang - // issues with pre-Volta architectures -#pragma unroll - for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { - if (lid == j * P::AccThCols) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rid = gridStrideY + accrowid + i * P::AccThRows; - if (rid < m) { - auto value = val[i]; - while (atomicCAS(mutex + rid, 0, 1) == 1) - ; - __threadfence(); - red_op(rid, min + rid, value); - __threadfence(); - atomicCAS(mutex + rid, 1, 0); - } - } - } +#include +#include +#include +#include +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This base class is meant to define the concept required of the +/// EpilogueWithBroadcast::OutputOp +template +class FusedL2NNEpilogueElementwise { + public: + using ElementOutput = ElementC_; + using ElementC = ElementC_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementZ = ElementZ_; + using ElementT = ElementT_; + static int const kElementsPerAccess = ElementsPerAccess; + static int const kCount = kElementsPerAccess; + + using DistanceOp = DistanceOp_; + using FinalOp = FinalOp_; + + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentC = Array; + using FragmentZ = Array; + using FragmentT = Array; + + using FragmentOutput = FragmentZ; + + static bool const kIsHeavy = false; // ElementwiseOp::kIsHeavy; + + /// If true, the 'Z' tensor is stored + static bool const kStoreZ = false; // We don't store anything in Z, + + /// If true, the 'T' tensor is stored + static bool const kStoreT = true; // this is our final output storage. + + /// Host-constructable parameters structure + struct Params { + FinalOp_ final_op_; + DistanceOp_ dist_op_; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params(DistanceOp_ dist_op, FinalOp final_op) : final_op_(final_op), dist_op_(dist_op) {} + + CUTLASS_HOST_DEVICE + Params() {} + }; + + private: + // + // Data members + // + FinalOp_ final_op; + DistanceOp_ elementwise_op; + + public: + // + // Methods + // + + /// Constructor from Params + CUTLASS_HOST_DEVICE + FusedL2NNEpilogueElementwise(Params const& params) + : final_op(params.final_op_), elementwise_op(params.dist_op_) + { + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const + { + // we use for making sure C matrix path is used for A mat norm. + return true; } -} - -template -void fusedL2NNImpl(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - int* workspace, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream) -{ - // The kernel policy is determined by fusedL2NN. - typedef Policy P; - - dim3 blk(P::Nthreads); - auto nblks = raft::ceildiv(m, P::Nthreads); - constexpr auto maxVal = std::numeric_limits::max(); - typedef KeyValuePair KVPair; - - // Accumulation operation lambda - auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; - - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); - if (initOutBuffer) { - initKernel - <<>>(min, m, maxVal, redOp); - RAFT_CUDA_TRY(cudaGetLastError()); + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + /// Applies the operation when is_source_needed() is true + CUTLASS_HOST_DEVICE + void operator()(FragmentZ& frag_Z, + FragmentT& frag_T, + FragmentAccumulator const& AB, + FragmentC const& frag_C, + FragmentCompute const& V) const + { + FragmentCompute tmp_Accum = + NumericArrayConverter()(AB); + FragmentCompute tmp_C = + NumericArrayConverter()(frag_C); + FragmentCompute result_Z; + FragmentCompute result_T; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kElementsPerAccess; ++i) { + result_Z[i] = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); + result_T[i] = final_op(result_Z[i], 0); + } + + NumericArrayConverter convert_t; + frag_T = convert_t(result_T); } - auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; - - constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - if (sqrt) { - auto fusedL2NNSqrt = fusedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NNSqrt); - - fusedL2NNSqrt<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, fin_op); - } else { - auto fusedL2NN = fusedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NN); - fusedL2NN<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, fin_op); + /// Applies the operation when is_source_needed() is false + CUTLASS_HOST_DEVICE + void operator()(FragmentZ& frag_Z, + FragmentT& frag_T, + FragmentAccumulator const& AB, + FragmentCompute const& V) const + { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// - RAFT_CUDA_TRY(cudaGetLastError()); -} +} // namespace thread +} // namespace epilogue +} // namespace cutlass -} // namespace detail -} // namespace distance -} // namespace raft +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h index 1d380bfbbf..c738572863 100755 --- a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h +++ b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h @@ -50,6 +50,8 @@ template < // typename ElementOutT, /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' typename EpilogueOutputOp, + typename ReduceOpT, + typename KVPReduceOpT, /// Number of stages used in the pipelined mainloop int Stages, /// data layout row/column major of inputs @@ -66,7 +68,7 @@ struct FusedL2NNGemm { /// Warp-level tile size (concept: GemmShape) // This code section describes the size of MMA op using InstructionShape = - cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 + cutlass::gemm::GemmShape<16, 8, 4>; // <- MMA Op tile M = 16, N = 8, K = 8 /// Operation performed by GEMM using Operator = cutlass::arch::OpMultiplyAddFastF32; diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h index e56b34c3b0..f00d098cd7 100755 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h @@ -60,6 +60,7 @@ namespace threadblock { template class PredicatedTileIteratorReducedVec { @@ -76,6 +77,7 @@ class PredicatedTileIteratorReducedVec { using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; using TensorCoord = MatrixCoord; + using EpilogueOpParams = EpilogueOpParams_ static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; static int const kThreads = ThreadMap::kThreads; @@ -103,6 +105,7 @@ class PredicatedTileIteratorReducedVec { struct Params : PredicatedTileIteratorParams { using Base = PredicatedTileIteratorParams; + EpilogueOpParams user_param; CUTLASS_HOST_DEVICE Params() {} @@ -156,7 +159,7 @@ class PredicatedTileIteratorReducedVec { // /// Parameters structure containing reference and precomputed state. - PredicatedTileIteratorParams params_; + Params params_; /// Byte-level pointer uint8_t* byte_pointer_; @@ -188,7 +191,7 @@ class PredicatedTileIteratorReducedVec { static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); - static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + static_assert(sizeof(Params::stride) == 8, "Expected 64b strides"); private: // @@ -202,7 +205,7 @@ class PredicatedTileIteratorReducedVec { /// Constructor CUTLASS_DEVICE - PredicatedTileIteratorReducedVec(PredicatedTileIteratorParams const& params, + PredicatedTileIteratorReducedVec(Params const& params, Element* pointer, TensorCoord extent, int thread_idx, From 356c2e1312faee424845b9a251df8c759a633e2c Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 21 Dec 2022 22:03:18 +0530 Subject: [PATCH 03/48] working code with atomicCAS based reduction from each thread --- .../raft/distance/detail/fused_l2_nn.cuh | 107 ++++++++++++------ .../detail/fused_l2_nn_cutlass_base.cuh | 37 +++--- .../distance/detail/fused_l2_nn_epilogue.cuh | 2 +- .../fused_l2_nn_epilogue_elementwise.cuh | 31 +++-- .../raft/distance/detail/fused_l2_nn_gemm.h | 11 +- .../predicated_tile_iterator_reduced_vec.h | 75 +++++++++--- .../neighbors/detail/connect_components.cuh | 7 +- 7 files changed, 187 insertions(+), 83 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index e0ca99478b..396d100c1c 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -22,6 +22,8 @@ #include #include #include +#include +#include namespace raft { namespace distance { @@ -43,7 +45,7 @@ struct KVPMinReduceImpl { template struct MinAndDistanceReduceOpImpl { typedef typename raft::KeyValuePair KVP; - DI void operator()(LabelT rid, KVP* out, const KVP& other) + DI void operator()(LabelT rid, KVP* out, const KVP& other) const { if (other.value < out->value) { out->key = other.key; @@ -51,17 +53,30 @@ struct MinAndDistanceReduceOpImpl { } } - DI void operator()(LabelT rid, DataT* out, const KVP& other) + DI void operator()(LabelT rid, DataT* out, const KVP& other) const { if (other.value < *out) { *out = other.value; } } - DI void init(DataT* out, DataT maxVal) { *out = maxVal; } - DI void init(KVP* out, DataT maxVal) + DI void operator()(LabelT rid, DataT* out, const DataT& other) const + { + if (other < *out) { *out = other; } + } + + + DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } + DI void init(KVP* out, DataT maxVal) const { out->key = 0; out->value = maxVal; } + + DI void init_key(DataT *out, LabelT idx) const { return; } + DI void init_key(KVP *out, LabelT idx) const + { + out->key = idx; + //out->value = maxVal; + } }; template @@ -260,13 +275,21 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, // final op functor for FusedL2NN used in its cutlass version // to convert the distance value & key(loc id) into key-value pair -template +template struct kvp_fin_op { + typedef typename raft::KeyValuePair KVP; + __host__ __device__ kvp_fin_op() noexcept {}; // functor signature. - __host__ __device__ OutType operator()(AccType d_val, Index idx) const noexcept + __host__ __device__ void operator()(KVP &a, AccType d_val, Index idx) const + { + a.value = d_val; + a.key = idx; + return; + } + __host__ __device__ void operator()(AccType &a, AccType d_val, Index idx) const { - return OutType(d_val, idx); + return; } }; @@ -311,34 +334,50 @@ void fusedL2NNImpl(OutT* min, auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; - constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - if (sqrt) { - auto fusedL2NNSqrt = fusedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NNSqrt); - - fusedL2NNSqrt<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, fin_op); + const auto deviceVersion = getComputeCapability(); + if (deviceVersion.first >= 8) { + using L2Op = L2ExpandedOp; + using final_op_kvp_ = kvp_fin_op; + final_op_kvp_ fin_op_kvp; + L2Op L2_dist_op(sqrt); + + IdxT lda, ldb, ldd; + lda = k, ldb = k, ldd = n; + + cutlassFusedL2NNKernel(x, y, xn, yn, m, n, k, + lda, ldb, ldd, min, workspace, fin_op_kvp, L2_dist_op, + redOp, pairRedOp, stream); } else { - auto fusedL2NN = fusedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NN); - fusedL2NN<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, fin_op); + constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + if (sqrt) { + auto fusedL2NNSqrt = fusedL2NNkernel; + dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NNSqrt); + + fusedL2NNSqrt<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, fin_op); + } else { + auto fusedL2NN = fusedL2NNkernel; + dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NN); + fusedL2NN<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, fin_op); + } } RAFT_CUDA_TRY(cudaGetLastError()); diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh index 31b44ab81f..da2526b796 100755 --- a/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh @@ -38,8 +38,9 @@ #include #include -#include "./pairwise_distance_epilogue_elementwise.h" -#include "./pairwise_distance_gemm.h" +#include "./fused_l2_nn_epilogue_elementwise.cuh" +#include "./fused_l2_nn_gemm.h" + #define CUTLASS_CHECK(status) \ { \ @@ -82,23 +83,25 @@ void cutlassFusedL2NNKernel(const DataT* x, KVPReduceOpT pairRedOp, cudaStream_t stream) { - static_assert(!(std::is_same::value), - "OutType bool is not supported use uint8_t instead"); + // static_assert(!(std::is_same::value), + // "OutType bool is not supported use uint8_t instead"); using EpilogueOutputOp = - cutlass::epilogue::thread::PairwiseDistanceEpilogueElementwise; + cutlass::epilogue::thread::FusedL2NNEpilogueElementwise; constexpr int batch_count = 1; constexpr auto mode = cutlass::gemm::GemmUniversalMode::kGemm; - typename EpilogueOutputOp::Params epilog_op_param(dist_op, fin_op); + typename EpilogueOutputOp::Params epilog_op_param(dist_op, fin_op, redOp, pairRedOp, mutexes); const DataT *a, *b; @@ -122,8 +125,6 @@ void cutlassFusedL2NNKernel(const DataT* x, AccT, AccT, EpilogueOutputOp, - ReduceOpT, - KVPReduceOpT, NumStages, // Number of pipeline stages isRowMajor>::GemmKernel; @@ -154,12 +155,12 @@ void cutlassFusedL2NNKernel(const DataT* x, (int64_t)0, (int64_t)0, // batch stride Norm B (int64_t)0, // batch stride Output - gemm_lda, // stride A - gemm_ldb, // stride B + (int64_t)gemm_lda, // stride A + (int64_t)gemm_ldb, // stride B 1, // stride A norm 0, // this is no-op for Z 0, // This must be zero - ldd // stride Output matrix + (int64_t)ldd // stride Output matrix }; // Using the arguments, query for extra workspace required for matrix multiplication computation diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh index 11c8f5482b..d507d59c5d 100755 --- a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh @@ -78,7 +78,7 @@ struct FusedL2NNEpilogue { // using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorReducedVec; + ElementTensor, LayoutT, typename OutputOp::Params>; /// Define the epilogue using Epilogue = EpilogueWithBroadcast + typename FinalOp_, + typename ReduceOpT_, + typename KVPReduceOpT_> class FusedL2NNEpilogueElementwise { public: using ElementOutput = ElementC_; @@ -85,12 +87,18 @@ class FusedL2NNEpilogueElementwise { struct Params { FinalOp_ final_op_; DistanceOp_ dist_op_; - + KVPReduceOpT_ pair_redop_; + ReduceOpT_ red_op_; + int *mutexes_; // // Methods // CUTLASS_HOST_DEVICE - Params(DistanceOp_ dist_op, FinalOp final_op) : final_op_(final_op), dist_op_(dist_op) {} + Params(DistanceOp_ dist_op, FinalOp final_op, + ReduceOpT_ red_op, KVPReduceOpT_ pair_redop, + int *mutexes) : + final_op_(final_op), dist_op_(dist_op), pair_redop_(pair_redop), + red_op_(red_op), mutexes_(mutexes) {} CUTLASS_HOST_DEVICE Params() {} @@ -102,6 +110,8 @@ class FusedL2NNEpilogueElementwise { // FinalOp_ final_op; DistanceOp_ elementwise_op; + KVPReduceOpT_ pair_redop; + ReduceOpT_ red_op; public: // @@ -111,7 +121,8 @@ class FusedL2NNEpilogueElementwise { /// Constructor from Params CUTLASS_HOST_DEVICE FusedL2NNEpilogueElementwise(Params const& params) - : final_op(params.final_op_), elementwise_op(params.dist_op_) + : final_op(params.final_op_), elementwise_op(params.dist_op_), + pair_redop(params.pair_redop_), red_op(params.red_op_) { } @@ -140,16 +151,18 @@ class FusedL2NNEpilogueElementwise { FragmentCompute tmp_C = NumericArrayConverter()(frag_C); FragmentCompute result_Z; - FragmentCompute result_T; + //FragmentT result_T; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kElementsPerAccess; ++i) { - result_Z[i] = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); - result_T[i] = final_op(result_Z[i], 0); + //result_Z[i] = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); + ElementCompute res_Z = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); + //frag_T[i] = final_op(result_Z[i], 0); + red_op.init(&frag_T[i], res_Z); } - NumericArrayConverter convert_t; - frag_T = convert_t(result_T); + // NumericArrayConverter convert_t; + // frag_T = convert_t(result_T); } /// Applies the operation when is_source_needed() is false diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h index c738572863..9482062e29 100755 --- a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h +++ b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h @@ -19,11 +19,12 @@ #include #include -#include +//#include #include #include #include "./fused_l2_nn_epilogue.cuh" +#include "./fusedL2NN_gemm_with_fused_epilogue.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -50,8 +51,6 @@ template < // typename ElementOutT, /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' typename EpilogueOutputOp, - typename ReduceOpT, - typename KVPReduceOpT, /// Number of stages used in the pipelined mainloop int Stages, /// data layout row/column major of inputs @@ -132,7 +131,8 @@ struct FusedL2NNGemm { GemmBase::Epilogue::kElementsPerAccess>::Epilogue; // Compose the GEMM kernel - using GemmKernel = GemmWithFusedEpilogue; + //using GemmKernel = GemmWithFusedEpilogue; + using GemmKernel = FusedL2NNWithFusedEpilogue; }; template < @@ -231,7 +231,8 @@ struct FusedL2NNGemm::Epilogue; // Compose the GEMM kernel - using GemmKernel = GemmWithFusedEpilogue; + //using GemmKernel = GemmWithFusedEpilogue; + using GemmKernel = FusedL2NNWithFusedEpilogue; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h index f00d098cd7..16696b2cf2 100755 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h @@ -77,9 +77,10 @@ class PredicatedTileIteratorReducedVec { using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; using TensorCoord = MatrixCoord; - using EpilogueOpParams = EpilogueOpParams_ + using EpilogueOpParams = EpilogueOpParams_; - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + //static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kElementsPerAccess = 1; static int const kThreads = ThreadMap::kThreads; static int const kIterations = ThreadMap::Count::kTile; @@ -89,13 +90,19 @@ class PredicatedTileIteratorReducedVec { static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); /// Fragment object + // using Fragment = Array; + using Fragment = Array; + kElementsPerAccess>; /// Memory access size - using AccessType = AlignedArray; + //using AccessType = AlignedArray; + using AccessType = AlignedArray; // // Parameters struct @@ -117,6 +124,16 @@ class PredicatedTileIteratorReducedVec { { } + CUTLASS_HOST_DEVICE + Params(Layout const& layout, EpilogueOpParams const& user_param_) + : PredicatedTileIteratorParams( + //layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()), + user_param(user_param_) + { + } + CUTLASS_HOST_DEVICE Params(Base const& base) : Base(base) {} }; @@ -237,6 +254,9 @@ class PredicatedTileIteratorReducedVec { byte_pointer_ = reinterpret_cast(pointer) + LongIndex(thread_offset.row()) * LongIndex(params_.stride); + // printf("blockId = %d threadId = %d thread_offset_row = %d stride = %d extent.row() = %d extent.column() = %d\n", + // (int)blockIdx.x, (int)threadIdx.x, (int)thread_offset.row(), (int)params_.stride, (int)extent.row(), (int)extent.column()); + if (ScatterD) { byte_pointer_ = reinterpret_cast(pointer) + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; @@ -314,10 +334,10 @@ class PredicatedTileIteratorReducedVec { /// Stores a fragment to memory CUTLASS_DEVICE - void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const + void store_with_byte_offset(Fragment& frag, int64_t byte_offset) const { uint8_t* byte_pointer = byte_pointer_; - AccessType const* frag_ptr = reinterpret_cast(&frag); + AccessType* frag_ptr = reinterpret_cast(&frag); CUTLASS_PRAGMA_UNROLL for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { @@ -342,7 +362,7 @@ class PredicatedTileIteratorReducedVec { byte_pointer + byte_offset + LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); } - +#if 1 CUTLASS_PRAGMA_UNROLL for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { bool guard = row_guard && mask_.predicates[column]; @@ -353,19 +373,45 @@ class PredicatedTileIteratorReducedVec { frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; } } else { - cutlass::arch::global_store( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], - guard); + int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn; + + if (guard) { + //printf("gmem column id = %d guard = %d \n", (int)(thread_start_column_ + ThreadMap::Delta::kColumn * column), (int) guard); + params_.user_param.red_op_.init_key(&(*frag_ptr)[frag_idx + column], thread_start_column_ + ThreadMap::Delta::kColumn * column); + params_.user_param.red_op_(thread_start_column_ + ThreadMap::Delta::kColumn * column, &(*frag_ptr)[frag_idx], (*frag_ptr)[frag_idx + column]); + } } } +#if 1 + if (row_guard ) { + // printf("blockIdx.x = %d threadIdx.x = %d sizeof(Element) = %d params_.increment_row = %d params_.increment_group = %d params_.increment_cluster = %d frag_row_idx = %d extent_row_ = %d rowid = %d extent.column() = %d\n", + // (int)blockIdx.x, (int)threadIdx.x, (int)sizeof(Element), (int)params_.increment_row, (int)params_.increment_group, (int)params_.increment_cluster, + // (int)frag_row_idx, (int)extent_row_, (int)(row_offset + thread_start_row_), (int)extent_column_); + + while (atomicCAS(params_.user_param.mutexes_ + row_offset + thread_start_row_, 0, 1) == 1) + ; + __threadfence(); + params_.user_param.red_op_(row_offset + thread_start_row_, (Element*)&memory_pointer[0], (*frag_ptr)[frag_row_idx * ThreadMap::Iterations::kColumn]); + // cutlass::arch::global_store( + // frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn], + // (void*)&memory_pointer[0], + // row_guard); + __threadfence(); + atomicCAS(params_.user_param.mutexes_ + row_offset + thread_start_row_, 1, 0); + } +#endif +#endif if (row + 1 < ThreadMap::Iterations::kRow) { - if (!ScatterD) { byte_pointer += params_.increment_row; } + if (!ScatterD) { + byte_pointer += params_.increment_row; + } } } - if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } } if (cluster + 1 < ThreadMap::Iterations::kCluster) { @@ -376,7 +422,7 @@ class PredicatedTileIteratorReducedVec { /// Stores a fragment to memory CUTLASS_DEVICE - void store(Fragment const& frag) const { store_with_byte_offset(frag, 0); } + void store(Fragment& frag) const { store_with_byte_offset(frag, 0); } /// Loads a fragment from memory CUTLASS_DEVICE @@ -545,6 +591,7 @@ class PredicatedTileIteratorReducedVec { (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; ++state_[2]; byte_pointer_ += params_.advance_cluster; diff --git a/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh b/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh index 38ba1137ac..643c7986d5 100644 --- a/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh +++ b/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh @@ -59,6 +59,9 @@ struct FixConnectivitiesRedOp { value_idx* colors; value_idx m; + // default constructor for cutlass + DI FixConnectivitiesRedOp() : colors(0), m(0) { } + FixConnectivitiesRedOp(value_idx* colors_, value_idx m_) : colors(colors_), m(m_){}; typedef typename raft::KeyValuePair KVP; @@ -80,8 +83,8 @@ struct FixConnectivitiesRedOp { return b; } - DI void init(value_t* out, value_t maxVal) { *out = maxVal; } - DI void init(KVP* out, value_t maxVal) + DI void init(value_t* out, value_t maxVal) const { *out = maxVal; } + DI void init(KVP* out, value_t maxVal) const { out->key = -1; out->value = maxVal; From 6663d4c754d07ba437b75758abcf781639c062d3 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 28 Dec 2022 00:11:34 +0530 Subject: [PATCH 04/48] improve perf by warp reduce and reduce register by adjusting gemm block/warp shape size, this now touches the perf of fusedL2NN simt kernel --- .../raft/distance/detail/fused_l2_nn.cuh | 22 ++++++-- .../raft/distance/detail/fused_l2_nn_gemm.h | 4 +- .../predicated_tile_iterator_reduced_vec.h | 52 +++++++++---------- 3 files changed, 46 insertions(+), 32 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 396d100c1c..63a0b43972 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -71,11 +71,10 @@ struct MinAndDistanceReduceOpImpl { out->value = maxVal; } - DI void init_key(DataT *out, LabelT idx) const { return; } - DI void init_key(KVP *out, LabelT idx) const + DI void init_key(DataT &out, LabelT idx) const { return; } + DI void init_key(KVP &out, LabelT idx) const { - out->key = idx; - //out->value = maxVal; + out.key = idx; } }; @@ -280,6 +279,7 @@ struct kvp_fin_op { typedef typename raft::KeyValuePair KVP; __host__ __device__ kvp_fin_op() noexcept {}; +#if 0 // functor signature. __host__ __device__ void operator()(KVP &a, AccType d_val, Index idx) const { @@ -291,6 +291,19 @@ struct kvp_fin_op { { return; } +#else + // functor signature. + __host__ __device__ KVP operator()(KVP a, KVP b) const + { + // a.value = d_val; + // a.key = idx; + return a.value < b.value ? a : b; + } + __host__ __device__ AccType operator()(AccType a, AccType b) const + { + return a < b ? a : b; + } +#endif }; template = 8) { using L2Op = L2ExpandedOp; using final_op_kvp_ = kvp_fin_op; diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h index 9482062e29..555f0ac786 100755 --- a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h +++ b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h @@ -60,10 +60,10 @@ struct FusedL2NNGemm { /// Threadblock-level tile size (concept: GemmShape) using ThreadblockShape = - cutlass::gemm::GemmShape<128, 128, 16>; // <- threadblock tile M = 128, N = 128, K = 16 + cutlass::gemm::GemmShape<32, 64, 16>; // <- threadblock tile M = 128, N = 128, K = 16 /// Warp-level tile size (concept: GemmShape) // This code section describes tile size a warp will compute - using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = 64, N = 64, K = 16 + using WarpShape = cutlass::gemm::GemmShape<16, 32, 16>; // <- warp tile M = 64, N = 64, K = 16 /// Warp-level tile size (concept: GemmShape) // This code section describes the size of MMA op using InstructionShape = diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h index 16696b2cf2..6fa2f252a9 100755 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h @@ -41,6 +41,10 @@ This file contains a customized version of PredicatedTileIterator from CUTLASS 2 #include #include #include +#include +#include + +namespace cg = cooperative_groups; //////////////////////////////////////////////////////////////////////////////// @@ -254,9 +258,6 @@ class PredicatedTileIteratorReducedVec { byte_pointer_ = reinterpret_cast(pointer) + LongIndex(thread_offset.row()) * LongIndex(params_.stride); - // printf("blockId = %d threadId = %d thread_offset_row = %d stride = %d extent.row() = %d extent.column() = %d\n", - // (int)blockIdx.x, (int)threadIdx.x, (int)thread_offset.row(), (int)params_.stride, (int)extent.row(), (int)extent.column()); - if (ScatterD) { byte_pointer_ = reinterpret_cast(pointer) + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; @@ -339,6 +340,9 @@ class PredicatedTileIteratorReducedVec { uint8_t* byte_pointer = byte_pointer_; AccessType* frag_ptr = reinterpret_cast(&frag); + cg::thread_block cta = cg::this_thread_block(); + cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta); + CUTLASS_PRAGMA_UNROLL for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { CUTLASS_PRAGMA_UNROLL @@ -362,7 +366,7 @@ class PredicatedTileIteratorReducedVec { byte_pointer + byte_offset + LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); } -#if 1 + CUTLASS_PRAGMA_UNROLL for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { bool guard = row_guard && mask_.predicates[column]; @@ -373,34 +377,30 @@ class PredicatedTileIteratorReducedVec { frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; } } else { - int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn; - + const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn; if (guard) { - //printf("gmem column id = %d guard = %d \n", (int)(thread_start_column_ + ThreadMap::Delta::kColumn * column), (int) guard); - params_.user_param.red_op_.init_key(&(*frag_ptr)[frag_idx + column], thread_start_column_ + ThreadMap::Delta::kColumn * column); + params_.user_param.red_op_.init_key((*frag_ptr)[frag_idx + column], thread_start_column_ + ThreadMap::Delta::kColumn * column); params_.user_param.red_op_(thread_start_column_ + ThreadMap::Delta::kColumn * column, &(*frag_ptr)[frag_idx], (*frag_ptr)[frag_idx + column]); } } } -#if 1 - if (row_guard ) { - // printf("blockIdx.x = %d threadIdx.x = %d sizeof(Element) = %d params_.increment_row = %d params_.increment_group = %d params_.increment_cluster = %d frag_row_idx = %d extent_row_ = %d rowid = %d extent.column() = %d\n", - // (int)blockIdx.x, (int)threadIdx.x, (int)sizeof(Element), (int)params_.increment_row, (int)params_.increment_group, (int)params_.increment_cluster, - // (int)frag_row_idx, (int)extent_row_, (int)(row_offset + thread_start_row_), (int)extent_column_); - - while (atomicCAS(params_.user_param.mutexes_ + row_offset + thread_start_row_, 0, 1) == 1) - ; - __threadfence(); - params_.user_param.red_op_(row_offset + thread_start_row_, (Element*)&memory_pointer[0], (*frag_ptr)[frag_row_idx * ThreadMap::Iterations::kColumn]); - // cutlass::arch::global_store( - // frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn], - // (void*)&memory_pointer[0], - // row_guard); - __threadfence(); - atomicCAS(params_.user_param.mutexes_ + row_offset + thread_start_row_, 1, 0); + + auto subTile = cg::binary_partition(tile32, row_guard && mask_.predicates[0]); + if (row_guard && mask_.predicates[0] ) { + + (*frag_ptr)[frag_row_idx * ThreadMap::Iterations::kColumn] = cg::reduce(subTile, (*frag_ptr)[frag_row_idx * ThreadMap::Iterations::kColumn], params_.user_param.final_op_); + + if (subTile.thread_rank() == 0) { + + while (atomicCAS(params_.user_param.mutexes_ + row_offset + thread_start_row_, 0, 1) == 1); + __threadfence(); + params_.user_param.red_op_(row_offset + thread_start_row_, + (Element*)&memory_pointer[0], + (*frag_ptr)[frag_row_idx * ThreadMap::Iterations::kColumn]); + __threadfence(); + atomicCAS(params_.user_param.mutexes_ + row_offset + thread_start_row_, 1, 0); + } } -#endif -#endif if (row + 1 < ThreadMap::Iterations::kRow) { if (!ScatterD) { From ca88ad1fdb96867a0f4a69f19edcc3ac8f5d59a5 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 30 Dec 2022 19:15:36 +0530 Subject: [PATCH 05/48] improve the perf of fusedL2NN cutlass kernel by reducing atomic locks from per row to multi-rows. now this kernel is 1.3x to 1.8x faster as k value increases perf gets better than fusedL2NN simt kernel --- .../raft/distance/detail/fused_l2_nn.cuh | 2 - .../raft/distance/detail/fused_l2_nn_gemm.h | 2 +- .../predicated_tile_iterator_reduced_vec.h | 67 ++++++++++++------- 3 files changed, 44 insertions(+), 27 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 63a0b43972..467cec44ce 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -295,8 +295,6 @@ struct kvp_fin_op { // functor signature. __host__ __device__ KVP operator()(KVP a, KVP b) const { - // a.value = d_val; - // a.key = idx; return a.value < b.value ? a : b; } __host__ __device__ AccType operator()(AccType a, AccType b) const diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h index 555f0ac786..72080598eb 100755 --- a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h +++ b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h @@ -81,7 +81,7 @@ struct FusedL2NNGemm { // This code section describes how threadblocks are scheduled on GPU /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>; /// data layout for final output matrix. // we keep this same layout even for column major inputs diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h index 6fa2f252a9..4bb2a6f85f 100755 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h @@ -94,18 +94,12 @@ class PredicatedTileIteratorReducedVec { static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); /// Fragment object - // using Fragment = Array; - using Fragment = Array; /// Memory access size - //using AccessType = AlignedArray; using AccessType = AlignedArray; // @@ -131,7 +125,6 @@ class PredicatedTileIteratorReducedVec { CUTLASS_HOST_DEVICE Params(Layout const& layout, EpilogueOpParams const& user_param_) : PredicatedTileIteratorParams( - //layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, int(sizeof(AccessType)) / kElementsPerAccess, make_OutputTileThreadMapDesc()), user_param(user_param_) @@ -224,6 +217,7 @@ class PredicatedTileIteratorReducedVec { // Methods // + /// Constructor CUTLASS_DEVICE PredicatedTileIteratorReducedVec(Params const& params, @@ -242,6 +236,7 @@ class PredicatedTileIteratorReducedVec { thread_start_row_ = thread_offset.row(); thread_start_column_ = thread_offset.column(); + // Initialize predicates CUTLASS_PRAGMA_UNROLL for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { @@ -367,6 +362,7 @@ class PredicatedTileIteratorReducedVec { LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); } + const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn; CUTLASS_PRAGMA_UNROLL for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { bool guard = row_guard && mask_.predicates[column]; @@ -377,36 +373,59 @@ class PredicatedTileIteratorReducedVec { frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; } } else { - const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn; if (guard) { - params_.user_param.red_op_.init_key((*frag_ptr)[frag_idx + column], thread_start_column_ + ThreadMap::Delta::kColumn * column); - params_.user_param.red_op_(thread_start_column_ + ThreadMap::Delta::kColumn * column, &(*frag_ptr)[frag_idx], (*frag_ptr)[frag_idx + column]); + const auto key_id = thread_start_column_ + ThreadMap::Delta::kColumn * column; + const int frag_col_idx = frag_idx + column; + params_.user_param.red_op_.init_key((*frag_ptr)[frag_col_idx], key_id); + params_.user_param.red_op_(key_id, &(*frag_ptr)[frag_idx], (*frag_ptr)[frag_col_idx]); } } } auto subTile = cg::binary_partition(tile32, row_guard && mask_.predicates[0]); - if (row_guard && mask_.predicates[0] ) { - (*frag_ptr)[frag_row_idx * ThreadMap::Iterations::kColumn] = cg::reduce(subTile, (*frag_ptr)[frag_row_idx * ThreadMap::Iterations::kColumn], params_.user_param.final_op_); + if (row_guard && mask_.predicates[0]) { + (*frag_ptr)[frag_idx] = cg::reduce(subTile, (*frag_ptr)[frag_idx], params_.user_param.final_op_); + } + if (tile32.thread_rank() > 0) { + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + } + + if (tile32.thread_rank() == 0 && thread_start_row_ < extent_row_) { - if (subTile.thread_rank() == 0) { + int *row_mutex = params_.user_param.mutexes_ + thread_start_row_; + while (atomicCAS(row_mutex, 0, 1) == 1); + __threadfence(); - while (atomicCAS(params_.user_param.mutexes_ + row_offset + thread_start_row_, 0, 1) == 1); - __threadfence(); - params_.user_param.red_op_(row_offset + thread_start_row_, - (Element*)&memory_pointer[0], - (*frag_ptr)[frag_row_idx * ThreadMap::Iterations::kColumn]); - __threadfence(); - atomicCAS(params_.user_param.mutexes_ + row_offset + thread_start_row_, 1, 0); + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + const int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn; + if (row_guard && mask_.predicates[0]) { + params_.user_param.red_op_(row_offset + thread_start_row_, + (Element*)&memory_pointer[0], + (*frag_ptr)[frag_idx]); } - } - if (row + 1 < ThreadMap::Iterations::kRow) { - if (!ScatterD) { - byte_pointer += params_.increment_row; + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } } } + __threadfence(); + atomicCAS(row_mutex, 1, 0); } if (group + 1 < ThreadMap::Iterations::kGroup) { From 4e4e6ffa6c883442b9d957a6914b12701be8e1d0 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 30 Dec 2022 21:13:03 +0530 Subject: [PATCH 06/48] add the custom gemm fused epilogue header required for passing params to predicated tile iterator --- .../fusedL2NN_gemm_with_fused_epilogue.h | 782 ++++++++++++++++++ 1 file changed, 782 insertions(+) create mode 100755 cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h diff --git a/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h b/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h new file mode 100755 index 0000000000..b93dbbbfa5 --- /dev/null +++ b/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h @@ -0,0 +1,782 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Gemm kernel with fused reduction operation. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function +> +struct FusedL2NNWithFusedEpilogue { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, + 128 / sizeof_bits::value + ); + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + typename EpilogueOutputOp::Params epilogue; + + void const * ptr_A; + void const * ptr_B; + void const * ptr_C; + void * ptr_D; + + void * ptr_Vector; + void * ptr_Tensor; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; + int64_t batch_stride_Vector; + int64_t batch_stride_Tensor; + + typename LayoutA::Stride::Index lda; + typename LayoutB::Stride::Index ldb; + typename LayoutC::Stride::Index ldc; + typename LayoutC::Stride::Index ldd; + typename LayoutC::Stride::Index ldr; + typename LayoutC::Stride::Index ldt; + + // + // Methods + // + + Arguments(): + mode(GemmUniversalMode::kGemm), + batch_count(1), + ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr) { } + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + void const * ptr_A, + void const * ptr_B, + void const * ptr_C, + void * ptr_D, + void * ptr_Vector, + void * ptr_Tensor, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D, + int64_t batch_stride_Vector, + int64_t batch_stride_Tensor, + typename LayoutA::Stride::Index lda, + typename LayoutB::Stride::Index ldb, + typename LayoutC::Stride::Index ldc, + typename LayoutC::Stride::Index ldd, + typename LayoutC::Stride::Index ldr, + typename LayoutC::Stride::Index ldt + ): + mode(mode), + problem_size(problem_size), + batch_count(batch_count), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + ptr_Vector(ptr_Vector), + ptr_Tensor(ptr_Tensor), + batch_stride_A(batch_stride_A), + batch_stride_B(batch_stride_B), + batch_stride_C(batch_stride_C), + batch_stride_D(batch_stride_D), + batch_stride_Vector(batch_stride_Vector), + batch_stride_Tensor(batch_stride_Tensor), + lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ldr(ldr), ldt(ldt) + { + CUTLASS_TRACE_HOST("FusedL2NNWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size); + CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); + CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); + CUTLASS_TRACE_HOST(" ldr: " << this->ldr); + CUTLASS_TRACE_HOST(" ldt: " << this->ldt); + } + + /// Returns arguments for the transposed problem + Arguments transposed_problem() const { + Arguments args(*this); + + std::swap(args.problem_size.m(), args.problem_size.n()); + std::swap(args.ptr_A, args.ptr_B); + std::swap(args.lda, args.ldb); + std::swap(args.batch_stride_A, args.batch_stride_B); + + return args; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::TensorTileIterator::Params params_Tensor; + + typename EpilogueOutputOp::Params output_op; + + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void * ptr_A; + void * ptr_B; + void * ptr_C; + void * ptr_D; + + void * ptr_Vector; + typename LayoutC::Stride::Index ldr; + + void * ptr_Tensor; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_C; + int64_t batch_stride_D; + int64_t batch_stride_Vector; + int64_t batch_stride_Tensor; + + int *semaphore; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + swizzle_log_tile(0), + params_A(0), + params_B(0), + params_C(0), + params_D(0), + batch_count(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + ptr_Vector(nullptr), + ldr(0), + ptr_Tensor(nullptr), + batch_stride_A(0), + batch_stride_B(0), + batch_stride_C(0), + batch_stride_D(0), + batch_stride_Vector(0), + batch_stride_Tensor(0), + semaphore(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + int gemm_k_size, + void *workspace = nullptr + ): + problem_size(args.problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(args.lda), + params_B(args.ldb), + params_C(args.ldc), + params_D(args.ldd), + params_Tensor(args.ldt, args.epilogue), + output_op(args.epilogue), + mode(args.mode), + batch_count(args.batch_count), + gemm_k_size(gemm_k_size), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_C(const_cast(args.ptr_C)), + ptr_D(args.ptr_D), + ptr_Vector(args.ptr_Vector), + ldr(args.ldr), + ptr_Tensor(args.ptr_Tensor), + + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D), + batch_stride_Vector(args.batch_stride_Vector), + batch_stride_Tensor(args.batch_stride_Tensor), + + semaphore(static_cast(workspace)) { + + CUTLASS_TRACE_HOST("FusedL2NNWithFusedEpilogue::Params::Params() - problem_size: " << problem_size); + CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); + CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); + CUTLASS_TRACE_HOST(" ldr: " << this->ldr); + CUTLASS_TRACE_HOST(" ldt: " << args.ldt); + } + + CUTLASS_HOST_DEVICE + void update( + Arguments const &args, + void *workspace = nullptr) { + + ptr_A = const_cast(args.ptr_A); + ptr_B = const_cast(args.ptr_B); + ptr_C = const_cast(args.ptr_C); + ptr_D = args.ptr_D; + + ptr_Vector = args.ptr_Vector; + ldr = args.ldr; + ptr_Tensor = args.ptr_Tensor; + + batch_stride_A = args.batch_stride_A; + batch_stride_B = args.batch_stride_B; + batch_stride_C = args.batch_stride_C; + batch_stride_D = args.batch_stride_D; + batch_stride_Vector = args.batch_stride_Vector; + batch_stride_Tensor = args.batch_stride_Tensor; + + output_op = args.epilogue; + + semaphore = static_cast(workspace); + + CUTLASS_TRACE_HOST("FusedL2NNWithFusedEpilogue::Params::update()"); + CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); + CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); + CUTLASS_TRACE_HOST(" ldr: " << this->ldr); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + FusedL2NNWithFusedEpilogue() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size) { + + CUTLASS_TRACE_HOST("FusedL2NNWithFusedEpilogue::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size); + } + + static size_t get_extra_workspace_size(Arguments const &args, + cutlass::gemm::GemmCoord const &grid_tiled_shape) { + + return 0; + } + + #define SPLIT_K_ENABLED 1 + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A = static_cast(params.ptr_A); + ElementB *ptr_B = static_cast(params.ptr_B); + + + #if SPLIT_K_ENABLED + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } + #endif + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_D = static_cast(params.ptr_D); + typename Epilogue::ElementTensor *ptr_Tensor = static_cast(params.ptr_Tensor); + + // Define the reduction output pointer and move to the appropriate place + typename Epilogue::ElementVector *ptr_Vector = + static_cast(params.ptr_Vector); + + // + // Fetch pointers based on mode. + // + + // + // Special path when split-K not enabled. + // + + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() == 1) { + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + ptr_C, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Additional tensor to load from + typename Epilogue::TensorTileIterator tensor_iterator( + params.params_Tensor, + // Only the final block outputs Tensor + ptr_Tensor, + params.problem_size.mn(), + thread_idx, + threadblock_offset); + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Move to appropriate location for this output tile + if (ptr_Vector) { + ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, + ptr_Vector, + iterator_D, + accumulators, + iterator_C, + tensor_iterator, + params.problem_size.mn(), + threadblock_offset); + + return; + } + + // + // Slower path when split-K or batching is needed + // + + + #if SPLIT_K_ENABLED + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == GemmUniversalMode::kGemm) { + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + } + else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; + ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + if (ptr_Tensor) { + ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor; + } + if (ptr_Vector) { + ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector; + } + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; + ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; + if (ptr_Tensor) { + ptr_Tensor = static_cast(params.ptr_Tensor)[threadblock_tile_offset.k()]; + } + if (ptr_Vector) { + ptr_Vector = static_cast(params.ptr_Vector)[threadblock_tile_offset.k()]; + } + } + #endif + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + ptr_C, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + ptr_D, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Additional tensor to load from + typename Epilogue::TensorTileIterator tensor_iterator( + params.params_Tensor, + // Only the final block outputs Tensor + ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && + (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) + ? nullptr + : ptr_Tensor, + params.problem_size.mn(), + thread_idx, + threadblock_offset); + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + #if SPLIT_K_ENABLED + // Wait on the semaphore - this latency may have been covered by iterator construction + if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + + } + #endif + + // Move to appropriate location for this output tile + if (ptr_Vector) { + ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, + // Only the final block uses Vector + ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && + (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) + ? nullptr + : ptr_Vector, + iterator_D, + accumulators, + iterator_C, + tensor_iterator, + params.problem_size.mn(), + threadblock_offset); + + // + // Release the semaphore + // + + #if SPLIT_K_ENABLED + if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + #endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// From 8af3ae799e1f3d53252c7cf7d50cd69805f1f5f2 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 30 Dec 2022 22:00:38 +0530 Subject: [PATCH 07/48] rename final_op as cg_reduce_op, cleanup --- .../raft/distance/detail/fused_l2_nn.cuh | 32 ++++++------------- .../fused_l2_nn_epilogue_elementwise.cuh | 15 ++++----- .../raft/distance/detail/fused_l2_nn_gemm.h | 2 -- .../predicated_tile_iterator_reduced_vec.h | 10 ++++-- 4 files changed, 24 insertions(+), 35 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 467cec44ce..ea36cf5bdb 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -272,26 +272,14 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, obj.run(); } -// final op functor for FusedL2NN used in its cutlass version -// to convert the distance value & key(loc id) into key-value pair +// cg::reduce functor for FusedL2NN used in its cutlass version +// to output the min distance value & key(loc id). template -struct kvp_fin_op { +struct kvp_cg_reduce_op { typedef typename raft::KeyValuePair KVP; - __host__ __device__ kvp_fin_op() noexcept {}; -#if 0 - // functor signature. - __host__ __device__ void operator()(KVP &a, AccType d_val, Index idx) const - { - a.value = d_val; - a.key = idx; - return; - } - __host__ __device__ void operator()(AccType &a, AccType d_val, Index idx) const - { - return; - } -#else + __host__ __device__ kvp_cg_reduce_op() noexcept {}; + // functor signature. __host__ __device__ KVP operator()(KVP a, KVP b) const { @@ -301,7 +289,7 @@ struct kvp_fin_op { { return a < b ? a : b; } -#endif + }; template = 8) { using L2Op = L2ExpandedOp; - using final_op_kvp_ = kvp_fin_op; - final_op_kvp_ fin_op_kvp; + using kvp_cg_reduce_op_ = kvp_cg_reduce_op; + kvp_cg_reduce_op_ cg_reduce_op; L2Op L2_dist_op(sqrt); IdxT lda, ldb, ldd; lda = k, ldb = k, ldd = n; cutlassFusedL2NNKernel(x, y, xn, yn, m, n, k, - lda, ldb, ldd, min, workspace, fin_op_kvp, L2_dist_op, + kvp_cg_reduce_op_, L2Op, ReduceOpT, KVPReduceOpT>(x, y, xn, yn, m, n, k, + lda, ldb, ldd, min, workspace, cg_reduce_op, L2_dist_op, redOp, pairRedOp, stream); } else { constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh index a977ba5765..c31b6709a3 100755 --- a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh @@ -50,7 +50,7 @@ template class FusedL2NNEpilogueElementwise { @@ -65,7 +65,7 @@ class FusedL2NNEpilogueElementwise { static int const kCount = kElementsPerAccess; using DistanceOp = DistanceOp_; - using FinalOp = FinalOp_; + using CGReduceOp = CGReduceOp_; using FragmentAccumulator = Array; using FragmentCompute = Array; @@ -85,7 +85,7 @@ class FusedL2NNEpilogueElementwise { /// Host-constructable parameters structure struct Params { - FinalOp_ final_op_; + CGReduceOp_ cg_reduce_op; DistanceOp_ dist_op_; KVPReduceOpT_ pair_redop_; ReduceOpT_ red_op_; @@ -94,10 +94,10 @@ class FusedL2NNEpilogueElementwise { // Methods // CUTLASS_HOST_DEVICE - Params(DistanceOp_ dist_op, FinalOp final_op, + Params(DistanceOp_ dist_op, CGReduceOp cg_reduce_op, ReduceOpT_ red_op, KVPReduceOpT_ pair_redop, int *mutexes) : - final_op_(final_op), dist_op_(dist_op), pair_redop_(pair_redop), + cg_reduce_op(cg_reduce_op), dist_op_(dist_op), pair_redop_(pair_redop), red_op_(red_op), mutexes_(mutexes) {} CUTLASS_HOST_DEVICE @@ -108,7 +108,6 @@ class FusedL2NNEpilogueElementwise { // // Data members // - FinalOp_ final_op; DistanceOp_ elementwise_op; KVPReduceOpT_ pair_redop; ReduceOpT_ red_op; @@ -121,7 +120,7 @@ class FusedL2NNEpilogueElementwise { /// Constructor from Params CUTLASS_HOST_DEVICE FusedL2NNEpilogueElementwise(Params const& params) - : final_op(params.final_op_), elementwise_op(params.dist_op_), + : elementwise_op(params.dist_op_), pair_redop(params.pair_redop_), red_op(params.red_op_) { } @@ -155,9 +154,7 @@ class FusedL2NNEpilogueElementwise { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kElementsPerAccess; ++i) { - //result_Z[i] = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); ElementCompute res_Z = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); - //frag_T[i] = final_op(result_Z[i], 0); red_op.init(&frag_T[i], res_Z); } diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h index 72080598eb..121ba70eb6 100755 --- a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h +++ b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h @@ -131,7 +131,6 @@ struct FusedL2NNGemm { GemmBase::Epilogue::kElementsPerAccess>::Epilogue; // Compose the GEMM kernel - //using GemmKernel = GemmWithFusedEpilogue; using GemmKernel = FusedL2NNWithFusedEpilogue; }; @@ -231,7 +230,6 @@ struct FusedL2NNGemm::Epilogue; // Compose the GEMM kernel - //using GemmKernel = GemmWithFusedEpilogue; using GemmKernel = FusedL2NNWithFusedEpilogue; }; diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h index 4bb2a6f85f..d6118aab81 100755 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h @@ -236,7 +236,9 @@ class PredicatedTileIteratorReducedVec { thread_start_row_ = thread_offset.row(); thread_start_column_ = thread_offset.column(); - + // if (blockIdx.x == 0 && blockIdx.y == 0) { + // printf("constructor tid = %d thread_start_row_ = %d thread_start_column_ = %d\n ", (int)threadIdx.x, (int)thread_start_row_, (int)thread_start_column_); + // } // Initialize predicates CUTLASS_PRAGMA_UNROLL for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { @@ -338,6 +340,10 @@ class PredicatedTileIteratorReducedVec { cg::thread_block cta = cg::this_thread_block(); cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta); + // if (blockIdx.x == 0 && blockIdx.y == 0) { + // printf("tid = %d thread_start_row_ = %d thread_start_column_ = %d\n ", (int)threadIdx.x, (int)thread_start_row_, (int)thread_start_column_); + // } + CUTLASS_PRAGMA_UNROLL for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { CUTLASS_PRAGMA_UNROLL @@ -385,7 +391,7 @@ class PredicatedTileIteratorReducedVec { auto subTile = cg::binary_partition(tile32, row_guard && mask_.predicates[0]); if (row_guard && mask_.predicates[0]) { - (*frag_ptr)[frag_idx] = cg::reduce(subTile, (*frag_ptr)[frag_idx], params_.user_param.final_op_); + (*frag_ptr)[frag_idx] = cg::reduce(subTile, (*frag_ptr)[frag_idx], params_.user_param.cg_reduce_op); } if (tile32.thread_rank() > 0) { if (row + 1 < ThreadMap::Iterations::kRow) { From d2c3833d5d724f28e49b782a119a892cdac69edd Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 3 Jan 2023 21:17:37 +0530 Subject: [PATCH 08/48] add whole block multi-row single lock impl which performs 5-7% faster than per warp multi row lock, cleanup and doc update --- .../fusedL2NN_gemm_with_fused_epilogue.h | 1 + .../detail/fused_l2_nn_cutlass_base.cuh | 25 ++----- .../fused_l2_nn_epilogue_elementwise.cuh | 7 +- .../raft/distance/detail/fused_l2_nn_gemm.h | 19 ++--- .../predicated_tile_iterator_reduced_vec.h | 69 ++++++++++++++----- 5 files changed, 72 insertions(+), 49 deletions(-) diff --git a/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h b/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h index b93dbbbfa5..e695f05ad9 100755 --- a/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h +++ b/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h @@ -290,6 +290,7 @@ struct FusedL2NNWithFusedEpilogue { params_B(args.ldb), params_C(args.ldc), params_D(args.ldd), + // Here we additional pass user args via args.epilogue params_Tensor(args.ldt, args.epilogue), output_op(args.epilogue), mode(args.mode), diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh index da2526b796..31175d12bd 100755 --- a/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,8 +19,6 @@ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" -#if (__CUDACC_VER_MAJOR__ < 12) - // We define CUTLASS_NAMESPACE in case // RAFT cmake is not used #ifndef CUTLASS_NAMESPACE @@ -83,9 +81,6 @@ void cutlassFusedL2NNKernel(const DataT* x, KVPReduceOpT pairRedOp, cudaStream_t stream) { - // static_assert(!(std::is_same::value), - // "OutType bool is not supported use uint8_t instead"); - using EpilogueOutputOp = cutlass::epilogue::thread::FusedL2NNEpilogueElementwise; - if constexpr (isRowMajor) { - a = y; - b = x; - gemm_lda = ldb; - gemm_ldb = lda; - } else { - problem_size = cutlass::gemm::GemmCoord(m, n, k); - a = x; - b = y; - gemm_lda = lda; - gemm_ldb = ldb; - } + a = y; + b = x; + gemm_lda = ldb; + gemm_ldb = lda; typename cutlassDist::Arguments arguments{ mode, problem_size, batch_count, epilog_op_param, a, b, @@ -183,5 +170,5 @@ void cutlassFusedL2NNKernel(const DataT* x, }; // namespace detail }; // namespace distance }; // namespace raft -#endif // (__CUDACC_VER_MAJOR__ < 12) + #pragma GCC diagnostic pop diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh index c31b6709a3..7663012195 100755 --- a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh @@ -89,14 +89,14 @@ class FusedL2NNEpilogueElementwise { DistanceOp_ dist_op_; KVPReduceOpT_ pair_redop_; ReduceOpT_ red_op_; - int *mutexes_; + volatile int *mutexes_; // // Methods // CUTLASS_HOST_DEVICE Params(DistanceOp_ dist_op, CGReduceOp cg_reduce_op, ReduceOpT_ red_op, KVPReduceOpT_ pair_redop, - int *mutexes) : + volatile int *mutexes) : cg_reduce_op(cg_reduce_op), dist_op_(dist_op), pair_redop_(pair_redop), red_op_(red_op), mutexes_(mutexes) {} @@ -150,7 +150,6 @@ class FusedL2NNEpilogueElementwise { FragmentCompute tmp_C = NumericArrayConverter()(frag_C); FragmentCompute result_Z; - //FragmentT result_T; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kElementsPerAccess; ++i) { @@ -158,8 +157,6 @@ class FusedL2NNEpilogueElementwise { red_op.init(&frag_T[i], res_Z); } - // NumericArrayConverter convert_t; - // frag_T = convert_t(result_T); } /// Applies the operation when is_source_needed() is false diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h index 121ba70eb6..aa20f348c2 100755 --- a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h +++ b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h @@ -59,15 +59,17 @@ struct FusedL2NNGemm { // This struct is specialized for fp32/3xTF32 /// Threadblock-level tile size (concept: GemmShape) + // <- threadblock tile M = 32, N = 64, K = 16 using ThreadblockShape = - cutlass::gemm::GemmShape<32, 64, 16>; // <- threadblock tile M = 128, N = 128, K = 16 + cutlass::gemm::GemmShape<32, 64, 16>; /// Warp-level tile size (concept: GemmShape) // This code section describes tile size a warp will compute - using WarpShape = cutlass::gemm::GemmShape<16, 32, 16>; // <- warp tile M = 64, N = 64, K = 16 + // <- warp tile M = 64, N = 64, K = 16 + using WarpShape = cutlass::gemm::GemmShape<16, 32, 16>; /// Warp-level tile size (concept: GemmShape) // This code section describes the size of MMA op - using InstructionShape = - cutlass::gemm::GemmShape<16, 8, 4>; // <- MMA Op tile M = 16, N = 8, K = 8 + // <- MMA Op tile M = 16, N = 8, K = 4 + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; /// Operation performed by GEMM using Operator = cutlass::arch::OpMultiplyAddFastF32; @@ -158,13 +160,14 @@ struct FusedL2NNGemm { - // using Transform = cutlass::ComplexTransform::kNone; + // Threadblock-level tile size (concept: GemmShape) - using ThreadblockShape = - cutlass::gemm::GemmShape<64, 64, 16>; // <- threadblock tile M = 64, N = 64, K = 16 + // <- threadblock tile M = 64, N = 64, K = 16 + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; /// Warp-level tile size (concept: GemmShape) // This code section describes tile size a warp will compute - using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; // <- warp tile M = 32, N = 32, K = 16 + // <- warp tile M = 32, N = 32, K = 16 + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; /// Warp-level tile size (concept: GemmShape) // This code section describes the size of MMA op using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h index d6118aab81..b6808f915d 100755 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h @@ -216,8 +216,6 @@ class PredicatedTileIteratorReducedVec { // // Methods // - - /// Constructor CUTLASS_DEVICE PredicatedTileIteratorReducedVec(Params const& params, @@ -236,9 +234,6 @@ class PredicatedTileIteratorReducedVec { thread_start_row_ = thread_offset.row(); thread_start_column_ = thread_offset.column(); - // if (blockIdx.x == 0 && blockIdx.y == 0) { - // printf("constructor tid = %d thread_start_row_ = %d thread_start_column_ = %d\n ", (int)threadIdx.x, (int)thread_start_row_, (int)thread_start_column_); - // } // Initialize predicates CUTLASS_PRAGMA_UNROLL for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { @@ -330,7 +325,7 @@ class PredicatedTileIteratorReducedVec { CUTLASS_DEVICE void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } - /// Stores a fragment to memory + /// Performs reduction and Stores a reduced output to memory CUTLASS_DEVICE void store_with_byte_offset(Fragment& frag, int64_t byte_offset) const { @@ -340,10 +335,6 @@ class PredicatedTileIteratorReducedVec { cg::thread_block cta = cg::this_thread_block(); cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta); - // if (blockIdx.x == 0 && blockIdx.y == 0) { - // printf("tid = %d thread_start_row_ = %d thread_start_column_ = %d\n ", (int)threadIdx.x, (int)thread_start_row_, (int)thread_start_column_); - // } - CUTLASS_PRAGMA_UNROLL for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { CUTLASS_PRAGMA_UNROLL @@ -387,12 +378,13 @@ class PredicatedTileIteratorReducedVec { } } } + bool col_guard = row_guard && mask_.predicates[0]; + auto subTile = cg::binary_partition(tile32, col_guard); - auto subTile = cg::binary_partition(tile32, row_guard && mask_.predicates[0]); - - if (row_guard && mask_.predicates[0]) { + if (col_guard) { (*frag_ptr)[frag_idx] = cg::reduce(subTile, (*frag_ptr)[frag_idx], params_.user_param.cg_reduce_op); } + if (tile32.thread_rank() > 0) { if (row + 1 < ThreadMap::Iterations::kRow) { if (!ScatterD) { @@ -401,11 +393,12 @@ class PredicatedTileIteratorReducedVec { } } } - +#if 0 + // single lock per warp for multiple rows if (tile32.thread_rank() == 0 && thread_start_row_ < extent_row_) { - int *row_mutex = params_.user_param.mutexes_ + thread_start_row_; - while (atomicCAS(row_mutex, 0, 1) == 1); + volatile int *row_mutex = params_.user_param.mutexes_ + thread_start_row_; + while (atomicCAS((int*)row_mutex, 0, 1) == 1); __threadfence(); CUTLASS_PRAGMA_UNROLL @@ -431,9 +424,51 @@ class PredicatedTileIteratorReducedVec { } } __threadfence(); - atomicCAS(row_mutex, 1, 0); + atomicCAS((int*)row_mutex, 1, 0); } +#else + // single lock per block for multiple rows + // this performs better for most of the cases than per warp lock. + if (threadIdx.x == 0 && thread_start_row_ < extent_row_) { + // acquire mutex lock. + volatile int *row_mutex = params_.user_param.mutexes_ + thread_start_row_; + while (atomicCAS((int*)row_mutex, 0, 1) == 1); + } + __syncthreads(); + if (tile32.thread_rank() == 0) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + const int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn; + if (row_guard && mask_.predicates[0]) { + // reduction with the current gmem value. + params_.user_param.red_op_(row_offset + thread_start_row_, + (Element*)&memory_pointer[0], + (*frag_ptr)[frag_idx]); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + } + + __syncthreads(); + __threadfence(); + if (threadIdx.x == 0 && thread_start_row_ < extent_row_) { + // release mutex lock. + volatile int *row_mutex = params_.user_param.mutexes_ + thread_start_row_; + atomicCAS((int*)row_mutex, 1, 0); + } +#endif if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } From b38414f6b5dbfb1d0a9743b9b3e1f30f9bd508c9 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 4 Jan 2023 13:38:41 +0530 Subject: [PATCH 09/48] fix connected components reduction functor for working with cutlass --- .../sparse/neighbors/detail/connect_components.cuh | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh b/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh index 643c7986d5..8bc48332e6 100644 --- a/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh +++ b/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh @@ -65,7 +65,7 @@ struct FixConnectivitiesRedOp { FixConnectivitiesRedOp(value_idx* colors_, value_idx m_) : colors(colors_), m(m_){}; typedef typename raft::KeyValuePair KVP; - DI void operator()(value_idx rit, KVP* out, const KVP& other) + DI void operator()(value_idx rit, KVP* out, const KVP& other) const { if (rit < m && other.value < out->value && colors[rit] != colors[other.key]) { out->key = other.key; @@ -73,9 +73,7 @@ struct FixConnectivitiesRedOp { } } - DI KVP - - operator()(value_idx rit, const KVP& a, const KVP& b) + DI KVP operator()(value_idx rit, const KVP& a, const KVP& b) const { if (rit < m && a.value < b.value && colors[rit] != colors[a.key]) { return a; @@ -89,6 +87,12 @@ struct FixConnectivitiesRedOp { out->key = -1; out->value = maxVal; } + + DI void init_key(value_t &out, value_idx idx) const { return; } + DI void init_key(KVP &out, value_idx idx) const + { + out.key = idx; + } }; /** From c89f1c1d2ac636fc948318dd47e7c1094b830835 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 4 Jan 2023 15:52:09 +0530 Subject: [PATCH 10/48] fix clang format and copyright year --- .../fusedL2NN_gemm_with_fused_epilogue.h | 603 ++++++++---------- .../raft/distance/detail/fused_l2_nn.cuh | 92 +-- .../detail/fused_l2_nn_cutlass_base.cuh | 78 +-- .../distance/detail/fused_l2_nn_epilogue.cuh | 10 +- .../fused_l2_nn_epilogue_elementwise.cuh | 27 +- .../raft/distance/detail/fused_l2_nn_gemm.h | 30 +- .../predicated_tile_iterator_reduced_vec.h | 92 ++- .../neighbors/detail/connect_components.cuh | 11 +- 8 files changed, 446 insertions(+), 497 deletions(-) mode change 100755 => 100644 cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h mode change 100755 => 100644 cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh mode change 100755 => 100644 cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh mode change 100755 => 100644 cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh mode change 100755 => 100644 cpp/include/raft/distance/detail/fused_l2_nn_gemm.h mode change 100755 => 100644 cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h diff --git a/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h b/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h old mode 100755 new mode 100644 index e695f05ad9..9cec7b96b4 --- a/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h +++ b/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -34,11 +34,11 @@ #pragma once +#include "cutlass/complex.h" #include "cutlass/cutlass.h" #include "cutlass/fast_math.h" #include "cutlass/gemm/gemm.h" #include "cutlass/matrix_coord.h" -#include "cutlass/complex.h" #include "cutlass/semaphore.h" #include "cutlass/trace.h" @@ -51,50 +51,46 @@ namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -template < - typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate - typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_ ///! Threadblock swizzling function -> +template struct FusedL2NNWithFusedEpilogue { -public: - - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; using ThreadblockSwizzle = ThreadblockSwizzle_; using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; + using LayoutA = typename Mma::IteratorA::Layout; using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; + using LayoutB = typename Mma::IteratorB::Layout; using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; static ComplexTransform const kTransformA = Mma::kTransformA; static ComplexTransform const kTransformB = Mma::kTransformB; - using Operator = typename Mma::Operator; + using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; + using OperatorClass = typename Mma::Operator::OperatorClass; using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; + using WarpShape = typename Mma::Operator::Shape; using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; + using ArchTag = typename Mma::ArchTag; - static int const kStages = Mma::kStages; + static int const kStages = Mma::kStages; static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; + using WarpCount = typename Mma::WarpCount; static int const kThreadCount = 32 * WarpCount::kCount; /// Split-K preserves splits that are 128b aligned - static int const kSplitKAlignment = const_max( - 128 / sizeof_bits::value, - 128 / sizeof_bits::value - ); + static int const kSplitKAlignment = + const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); // // Structures @@ -102,7 +98,6 @@ struct FusedL2NNWithFusedEpilogue { /// Argument structure struct Arguments { - // // Data members // @@ -113,13 +108,13 @@ struct FusedL2NNWithFusedEpilogue { typename EpilogueOutputOp::Params epilogue; - void const * ptr_A; - void const * ptr_B; - void const * ptr_C; - void * ptr_D; + void const* ptr_A; + void const* ptr_B; + void const* ptr_C; + void* ptr_D; - void * ptr_Vector; - void * ptr_Tensor; + void* ptr_Vector; + void* ptr_Tensor; int64_t batch_stride_A; int64_t batch_stride_B; @@ -138,63 +133,76 @@ struct FusedL2NNWithFusedEpilogue { // // Methods // - - Arguments(): - mode(GemmUniversalMode::kGemm), - batch_count(1), - ptr_A(nullptr), ptr_B(nullptr), ptr_C(nullptr), ptr_D(nullptr) { } + + Arguments() + : mode(GemmUniversalMode::kGemm), + batch_count(1), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr) + { + } /// constructs an arguments structure - Arguments( - GemmUniversalMode mode, - GemmCoord problem_size, - int batch_count, - typename EpilogueOutputOp::Params epilogue, - void const * ptr_A, - void const * ptr_B, - void const * ptr_C, - void * ptr_D, - void * ptr_Vector, - void * ptr_Tensor, - int64_t batch_stride_A, - int64_t batch_stride_B, - int64_t batch_stride_C, - int64_t batch_stride_D, - int64_t batch_stride_Vector, - int64_t batch_stride_Tensor, - typename LayoutA::Stride::Index lda, - typename LayoutB::Stride::Index ldb, - typename LayoutC::Stride::Index ldc, - typename LayoutC::Stride::Index ldd, - typename LayoutC::Stride::Index ldr, - typename LayoutC::Stride::Index ldt - ): - mode(mode), - problem_size(problem_size), - batch_count(batch_count), - epilogue(epilogue), - ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), - ptr_Vector(ptr_Vector), - ptr_Tensor(ptr_Tensor), - batch_stride_A(batch_stride_A), - batch_stride_B(batch_stride_B), - batch_stride_C(batch_stride_C), - batch_stride_D(batch_stride_D), - batch_stride_Vector(batch_stride_Vector), - batch_stride_Tensor(batch_stride_Tensor), - lda(lda), ldb(ldb), ldc(ldc), ldd(ldd), ldr(ldr), ldt(ldt) + Arguments(GemmUniversalMode mode, + GemmCoord problem_size, + int batch_count, + typename EpilogueOutputOp::Params epilogue, + void const* ptr_A, + void const* ptr_B, + void const* ptr_C, + void* ptr_D, + void* ptr_Vector, + void* ptr_Tensor, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D, + int64_t batch_stride_Vector, + int64_t batch_stride_Tensor, + typename LayoutA::Stride::Index lda, + typename LayoutB::Stride::Index ldb, + typename LayoutC::Stride::Index ldc, + typename LayoutC::Stride::Index ldd, + typename LayoutC::Stride::Index ldr, + typename LayoutC::Stride::Index ldt) + : mode(mode), + problem_size(problem_size), + batch_count(batch_count), + epilogue(epilogue), + ptr_A(ptr_A), + ptr_B(ptr_B), + ptr_C(ptr_C), + ptr_D(ptr_D), + ptr_Vector(ptr_Vector), + ptr_Tensor(ptr_Tensor), + batch_stride_A(batch_stride_A), + batch_stride_B(batch_stride_B), + batch_stride_C(batch_stride_C), + batch_stride_D(batch_stride_D), + batch_stride_Vector(batch_stride_Vector), + batch_stride_Tensor(batch_stride_Tensor), + lda(lda), + ldb(ldb), + ldc(ldc), + ldd(ldd), + ldr(ldr), + ldt(ldt) { - CUTLASS_TRACE_HOST("FusedL2NNWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size); - CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); - CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); + CUTLASS_TRACE_HOST( + "FusedL2NNWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size); + CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void*)this->ptr_Reduction); + CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void*)this->ptr_Tensor); CUTLASS_TRACE_HOST(" ldr: " << this->ldr); CUTLASS_TRACE_HOST(" ldt: " << this->ldt); } /// Returns arguments for the transposed problem - Arguments transposed_problem() const { + Arguments transposed_problem() const + { Arguments args(*this); - + std::swap(args.problem_size.m(), args.problem_size.n()); std::swap(args.ptr_A, args.ptr_B); std::swap(args.lda, args.ldb); @@ -210,7 +218,6 @@ struct FusedL2NNWithFusedEpilogue { /// Parameters structure struct Params { - cutlass::gemm::GemmCoord problem_size; cutlass::gemm::GemmCoord grid_tiled_shape; int swizzle_log_tile; @@ -220,23 +227,22 @@ struct FusedL2NNWithFusedEpilogue { typename Epilogue::OutputTileIterator::Params params_C; typename Epilogue::OutputTileIterator::Params params_D; typename Epilogue::TensorTileIterator::Params params_Tensor; - - typename EpilogueOutputOp::Params output_op; + typename EpilogueOutputOp::Params output_op; GemmUniversalMode mode; int batch_count; int gemm_k_size; - void * ptr_A; - void * ptr_B; - void * ptr_C; - void * ptr_D; - - void * ptr_Vector; + void* ptr_A; + void* ptr_B; + void* ptr_C; + void* ptr_D; + + void* ptr_Vector; typename LayoutC::Stride::Index ldr; - void * ptr_Tensor; + void* ptr_Tensor; int64_t batch_stride_A; int64_t batch_stride_B; @@ -245,109 +251,108 @@ struct FusedL2NNWithFusedEpilogue { int64_t batch_stride_Vector; int64_t batch_stride_Tensor; - int *semaphore; + int* semaphore; // // Methods // CUTLASS_HOST_DEVICE - Params(): - swizzle_log_tile(0), - params_A(0), - params_B(0), - params_C(0), - params_D(0), - batch_count(0), - gemm_k_size(0), - mode(cutlass::gemm::GemmUniversalMode::kGemm), - ptr_A(nullptr), - ptr_B(nullptr), - ptr_C(nullptr), - ptr_D(nullptr), - ptr_Vector(nullptr), - ldr(0), - ptr_Tensor(nullptr), - batch_stride_A(0), - batch_stride_B(0), - batch_stride_C(0), - batch_stride_D(0), - batch_stride_Vector(0), - batch_stride_Tensor(0), - semaphore(nullptr) { } + Params() + : swizzle_log_tile(0), + params_A(0), + params_B(0), + params_C(0), + params_D(0), + batch_count(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + ptr_Vector(nullptr), + ldr(0), + ptr_Tensor(nullptr), + batch_stride_A(0), + batch_stride_B(0), + batch_stride_C(0), + batch_stride_D(0), + batch_stride_Vector(0), + batch_stride_Tensor(0), + semaphore(nullptr) + { + } CUTLASS_HOST_DEVICE - Params( - Arguments const &args, - cutlass::gemm::GemmCoord const & grid_tiled_shape, - int gemm_k_size, - void *workspace = nullptr - ): - problem_size(args.problem_size), - grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), - params_A(args.lda), - params_B(args.ldb), - params_C(args.ldc), - params_D(args.ldd), - // Here we additional pass user args via args.epilogue - params_Tensor(args.ldt, args.epilogue), - output_op(args.epilogue), - mode(args.mode), - batch_count(args.batch_count), - gemm_k_size(gemm_k_size), - ptr_A(const_cast(args.ptr_A)), - ptr_B(const_cast(args.ptr_B)), - ptr_C(const_cast(args.ptr_C)), - ptr_D(args.ptr_D), - ptr_Vector(args.ptr_Vector), - ldr(args.ldr), - ptr_Tensor(args.ptr_Tensor), - - batch_stride_A(args.batch_stride_A), - batch_stride_B(args.batch_stride_B), - batch_stride_C(args.batch_stride_C), - batch_stride_D(args.batch_stride_D), - batch_stride_Vector(args.batch_stride_Vector), - batch_stride_Tensor(args.batch_stride_Tensor), - - semaphore(static_cast(workspace)) { - - CUTLASS_TRACE_HOST("FusedL2NNWithFusedEpilogue::Params::Params() - problem_size: " << problem_size); - CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); - CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); + Params(Arguments const& args, + cutlass::gemm::GemmCoord const& grid_tiled_shape, + int gemm_k_size, + void* workspace = nullptr) + : problem_size(args.problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(args.lda), + params_B(args.ldb), + params_C(args.ldc), + params_D(args.ldd), + // Here we additional pass user args via args.epilogue + params_Tensor(args.ldt, args.epilogue), + output_op(args.epilogue), + mode(args.mode), + batch_count(args.batch_count), + gemm_k_size(gemm_k_size), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_C(const_cast(args.ptr_C)), + ptr_D(args.ptr_D), + ptr_Vector(args.ptr_Vector), + ldr(args.ldr), + ptr_Tensor(args.ptr_Tensor), + + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D), + batch_stride_Vector(args.batch_stride_Vector), + batch_stride_Tensor(args.batch_stride_Tensor), + + semaphore(static_cast(workspace)) + { + CUTLASS_TRACE_HOST( + "FusedL2NNWithFusedEpilogue::Params::Params() - problem_size: " << problem_size); + CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void*)this->ptr_Reduction); + CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void*)this->ptr_Tensor); CUTLASS_TRACE_HOST(" ldr: " << this->ldr); CUTLASS_TRACE_HOST(" ldt: " << args.ldt); } CUTLASS_HOST_DEVICE - void update( - Arguments const &args, - void *workspace = nullptr) { - - ptr_A = const_cast(args.ptr_A); - ptr_B = const_cast(args.ptr_B); - ptr_C = const_cast(args.ptr_C); + void update(Arguments const& args, void* workspace = nullptr) + { + ptr_A = const_cast(args.ptr_A); + ptr_B = const_cast(args.ptr_B); + ptr_C = const_cast(args.ptr_C); ptr_D = args.ptr_D; ptr_Vector = args.ptr_Vector; - ldr = args.ldr; + ldr = args.ldr; ptr_Tensor = args.ptr_Tensor; - batch_stride_A = args.batch_stride_A; - batch_stride_B = args.batch_stride_B; - batch_stride_C = args.batch_stride_C; - batch_stride_D = args.batch_stride_D; + batch_stride_A = args.batch_stride_A; + batch_stride_B = args.batch_stride_B; + batch_stride_C = args.batch_stride_C; + batch_stride_D = args.batch_stride_D; batch_stride_Vector = args.batch_stride_Vector; batch_stride_Tensor = args.batch_stride_Tensor; output_op = args.epilogue; - semaphore = static_cast(workspace); + semaphore = static_cast(workspace); CUTLASS_TRACE_HOST("FusedL2NNWithFusedEpilogue::Params::update()"); - CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void *)this->ptr_Reduction); - CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void *)this->ptr_Tensor); + CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void*)this->ptr_Reduction); + CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void*)this->ptr_Tensor); CUTLASS_TRACE_HOST(" ldr: " << this->ldr); } }; @@ -358,19 +363,17 @@ struct FusedL2NNWithFusedEpilogue { typename Epilogue::SharedStorage epilogue; }; -public: - + public: // // Methods // CUTLASS_DEVICE - FusedL2NNWithFusedEpilogue() { } + FusedL2NNWithFusedEpilogue() {} /// Determines whether kernel satisfies alignment - static Status can_implement( - cutlass::gemm::GemmCoord const & problem_size) { - + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) + { CUTLASS_TRACE_HOST("FusedL2NNWithFusedEpilogue::can_implement()"); static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; @@ -385,8 +388,8 @@ struct FusedL2NNWithFusedEpilogue { isAMisaligned = problem_size.k() % kAlignmentA; } else if (platform::is_same::value) { isAMisaligned = problem_size.m() % kAlignmentA; - } else if (platform::is_same>::value - || platform::is_same>::value) { + } else if (platform::is_same>::value || + platform::is_same>::value) { isAMisaligned = problem_size.k() % kAlignmentA; } @@ -394,8 +397,8 @@ struct FusedL2NNWithFusedEpilogue { isBMisaligned = problem_size.n() % kAlignmentB; } else if (platform::is_same::value) { isBMisaligned = problem_size.k() % kAlignmentB; - } else if (platform::is_same>::value - || platform::is_same>::value) { + } else if (platform::is_same>::value || + platform::is_same>::value) { isBMisaligned = problem_size.k() % kAlignmentB; } @@ -403,8 +406,8 @@ struct FusedL2NNWithFusedEpilogue { isCMisaligned = problem_size.n() % kAlignmentC; } else if (platform::is_same::value) { isCMisaligned = problem_size.m() % kAlignmentC; - } else if (platform::is_same>::value - || platform::is_same>::value) { + } else if (platform::is_same>::value || + platform::is_same>::value) { isCMisaligned = problem_size.n() % kAlignmentC; } @@ -428,64 +431,57 @@ struct FusedL2NNWithFusedEpilogue { return Status::kSuccess; } - static Status can_implement(Arguments const &args) { - return can_implement(args.problem_size); - } - - static size_t get_extra_workspace_size(Arguments const &args, - cutlass::gemm::GemmCoord const &grid_tiled_shape) { + static Status can_implement(Arguments const& args) { return can_implement(args.problem_size); } + static size_t get_extra_workspace_size(Arguments const& args, + cutlass::gemm::GemmCoord const& grid_tiled_shape) + { return 0; } - #define SPLIT_K_ENABLED 1 +#define SPLIT_K_ENABLED 1 /// Executes one GEMM CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - + void operator()(Params const& params, SharedStorage& shared_storage) + { // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || - params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { - + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { return; } - int offset_k = 0; + int offset_k = 0; int problem_size_k = params.problem_size.k(); - ElementA *ptr_A = static_cast(params.ptr_A); - ElementB *ptr_B = static_cast(params.ptr_B); + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); - - #if SPLIT_K_ENABLED +#if SPLIT_K_ENABLED // // Fetch pointers based on mode. // - if (params.mode == GemmUniversalMode::kGemm || - params.mode == GemmUniversalMode::kGemmSplitKParallel) { - + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { - - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; } offset_k = threadblock_tile_offset.k() * params.gemm_k_size; - } - else if (params.mode == GemmUniversalMode::kBatched) { + } else if (params.mode == GemmUniversalMode::kBatched) { ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; } - else if (params.mode == GemmUniversalMode::kArray) { - ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; - ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; - } - #endif +#endif // Compute initial location in logical coordinates cutlass::MatrixCoord tb_offset_A{ @@ -493,28 +489,17 @@ struct FusedL2NNWithFusedEpilogue { offset_k, }; - cutlass::MatrixCoord tb_offset_B{ - offset_k, - threadblock_tile_offset.n() * Mma::Shape::kN - }; + cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; // Compute position within threadblock int thread_idx = threadIdx.x; // Construct iterators to A and B operands typename Mma::IteratorA iterator_A( - params.params_A, - ptr_A, - {params.problem_size.m(), problem_size_k}, - thread_idx, - tb_offset_A); + params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); typename Mma::IteratorB iterator_B( - params.params_B, - ptr_B, - {problem_size_k, params.problem_size.n()}, - thread_idx, - tb_offset_B); + params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B); // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. @@ -537,12 +522,7 @@ struct FusedL2NNWithFusedEpilogue { int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; // Compute threadblock-scoped matrix multiply-add - mma( - gemm_k_iterations, - accumulators, - iterator_A, - iterator_B, - accumulators); + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); // // Epilogue @@ -556,65 +536,49 @@ struct FusedL2NNWithFusedEpilogue { threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - //assume identity swizzle - MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.n() * Mma::Shape::kN - ); + // assume identity swizzle + MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN); - int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + int block_idx = + threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - ElementC *ptr_C = static_cast(params.ptr_C); - ElementC *ptr_D = static_cast(params.ptr_D); - typename Epilogue::ElementTensor *ptr_Tensor = static_cast(params.ptr_Tensor); + ElementC* ptr_C = static_cast(params.ptr_C); + ElementC* ptr_D = static_cast(params.ptr_D); + typename Epilogue::ElementTensor* ptr_Tensor = + static_cast(params.ptr_Tensor); // Define the reduction output pointer and move to the appropriate place - typename Epilogue::ElementVector *ptr_Vector = - static_cast(params.ptr_Vector); + typename Epilogue::ElementVector* ptr_Vector = + static_cast(params.ptr_Vector); // // Fetch pointers based on mode. // - + // // Special path when split-K not enabled. - // + // if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() == 1) { - // Tile iterator loading from source tensor. typename Epilogue::OutputTileIterator iterator_C( - params.params_C, - ptr_C, - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); + params.params_C, ptr_C, params.problem_size.mn(), thread_idx, threadblock_offset); // Tile iterator writing to destination tensor. typename Epilogue::OutputTileIterator iterator_D( - params.params_D, - ptr_D, - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); + params.params_D, ptr_D, params.problem_size.mn(), thread_idx, threadblock_offset); // Additional tensor to load from - typename Epilogue::TensorTileIterator tensor_iterator( - params.params_Tensor, - // Only the final block outputs Tensor - ptr_Tensor, - params.problem_size.mn(), - thread_idx, - threadblock_offset); + typename Epilogue::TensorTileIterator tensor_iterator(params.params_Tensor, + // Only the final block outputs Tensor + ptr_Tensor, + params.problem_size.mn(), + thread_idx, + threadblock_offset); // Construct the epilogue - Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, - lane_idx); + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); // Move to appropriate location for this output tile if (ptr_Vector) { @@ -638,98 +602,72 @@ struct FusedL2NNWithFusedEpilogue { // Slower path when split-K or batching is needed // - - #if SPLIT_K_ENABLED +#if SPLIT_K_ENABLED // Construct the semaphore. Semaphore semaphore(params.semaphore + block_idx, thread_idx); if (params.mode == GemmUniversalMode::kGemm) { - // If performing a reduction via split-K, fetch the initial synchronization if (params.grid_tiled_shape.k() > 1) { - // Fetch the synchronization lock initially but do not block. semaphore.fetch(); // Indicate which position in a serial reduction the output operator is currently updating output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); } - } - else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { + } else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; - } - else if (params.mode == GemmUniversalMode::kBatched) { + } else if (params.mode == GemmUniversalMode::kBatched) { ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; + if (ptr_Tensor) { ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor; } + if (ptr_Vector) { ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector; } + } else if (params.mode == GemmUniversalMode::kArray) { + ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; + ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; if (ptr_Tensor) { - ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor; + ptr_Tensor = static_cast( + params.ptr_Tensor)[threadblock_tile_offset.k()]; } if (ptr_Vector) { - ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector; + ptr_Vector = static_cast( + params.ptr_Vector)[threadblock_tile_offset.k()]; } } - else if (params.mode == GemmUniversalMode::kArray) { - ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; - ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; - if (ptr_Tensor) { - ptr_Tensor = static_cast(params.ptr_Tensor)[threadblock_tile_offset.k()]; - } - if (ptr_Vector) { - ptr_Vector = static_cast(params.ptr_Vector)[threadblock_tile_offset.k()]; - } - } - #endif +#endif // Tile iterator loading from source tensor. typename Epilogue::OutputTileIterator iterator_C( - params.params_C, - ptr_C, - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); + params.params_C, ptr_C, params.problem_size.mn(), thread_idx, threadblock_offset); // Tile iterator writing to destination tensor. typename Epilogue::OutputTileIterator iterator_D( - params.params_D, - ptr_D, - params.problem_size.mn(), - thread_idx, - threadblock_offset - ); + params.params_D, ptr_D, params.problem_size.mn(), thread_idx, threadblock_offset); // Additional tensor to load from typename Epilogue::TensorTileIterator tensor_iterator( - params.params_Tensor, - // Only the final block outputs Tensor - ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && - (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) - ? nullptr - : ptr_Tensor, - params.problem_size.mn(), - thread_idx, - threadblock_offset); + params.params_Tensor, + // Only the final block outputs Tensor + ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && + (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) + ? nullptr + : ptr_Tensor, + params.problem_size.mn(), + thread_idx, + threadblock_offset); // Construct the epilogue - Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, - lane_idx); + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); - #if SPLIT_K_ENABLED +#if SPLIT_K_ENABLED // Wait on the semaphore - this latency may have been covered by iterator construction if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - if (threadblock_tile_offset.k()) { - iterator_C = iterator_D; - } + if (threadblock_tile_offset.k()) { iterator_C = iterator_D; } semaphore.wait(threadblock_tile_offset.k()); - } - #endif +#endif // Move to appropriate location for this output tile if (ptr_Vector) { @@ -741,8 +679,8 @@ struct FusedL2NNWithFusedEpilogue { // Only the final block uses Vector ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) - ? nullptr - : ptr_Vector, + ? nullptr + : ptr_Vector, iterator_D, accumulators, iterator_C, @@ -754,30 +692,27 @@ struct FusedL2NNWithFusedEpilogue { // Release the semaphore // - #if SPLIT_K_ENABLED - if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { - +#if SPLIT_K_ENABLED + if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { int lock = 0; if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { - // The final threadblock resets the semaphore for subsequent grids. lock = 0; - } - else { + } else { // Otherwise, the semaphore is incremented lock = threadblock_tile_offset.k() + 1; } - + semaphore.release(lock); } - #endif +#endif } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace kernel -} // namespace gemm -} // namespace cutlass +} // namespace kernel +} // namespace gemm +} // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 9af6fcc15d..5311a26d19 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,12 +18,12 @@ #include #include +#include +#include #include #include #include #include -#include -#include namespace raft { namespace distance { @@ -63,7 +63,6 @@ struct MinAndDistanceReduceOpImpl { if (other < *out) { *out = other; } } - DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } DI void init(KVP* out, DataT maxVal) const { @@ -71,11 +70,8 @@ struct MinAndDistanceReduceOpImpl { out->value = maxVal; } - DI void init_key(DataT &out, LabelT idx) const { return; } - DI void init_key(KVP &out, LabelT idx) const - { - out.key = idx; - } + DI void init_key(DataT& out, LabelT idx) const { return; } + DI void init_key(KVP& out, LabelT idx) const { out.key = idx; } }; template @@ -275,22 +271,15 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, // cg::reduce functor for FusedL2NN used in its cutlass version // to output the min distance value & key(loc id). -template +template struct kvp_cg_reduce_op { typedef typename raft::KeyValuePair KVP; __host__ __device__ kvp_cg_reduce_op() noexcept {}; // functor signature. - __host__ __device__ KVP operator()(KVP a, KVP b) const - { - return a.value < b.value ? a : b; - } - __host__ __device__ AccType operator()(AccType a, AccType b) const - { - return a < b ? a : b; - } - + __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } + __host__ __device__ AccType operator()(AccType a, AccType b) const { return a < b ? a : b; } }; template = 8) { - using L2Op = L2ExpandedOp; + using L2Op = L2ExpandedOp; using kvp_cg_reduce_op_ = kvp_cg_reduce_op; kvp_cg_reduce_op_ cg_reduce_op; L2Op L2_dist_op(sqrt); @@ -343,37 +332,58 @@ void fusedL2NNImpl(OutT* min, IdxT lda, ldb, ldd; lda = k, ldb = k, ldd = n; - cutlassFusedL2NNKernel(x, y, xn, yn, m, n, k, - lda, ldb, ldd, min, workspace, cg_reduce_op, L2_dist_op, - redOp, pairRedOp, stream); + cutlassFusedL2NNKernel(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + min, + workspace, + cg_reduce_op, + L2_dist_op, + redOp, + pairRedOp, + stream); } else { - auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; + auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); if (sqrt) { auto fusedL2NNSqrt = fusedL2NNkernel; + OutT, + IdxT, + true, + P, + ReduceOpT, + KVPReduceOpT, + decltype(core_lambda), + decltype(fin_op)>; dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NNSqrt); fusedL2NNSqrt<<>>( min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, fin_op); } else { auto fusedL2NN = fusedL2NNkernel; + OutT, + IdxT, + false, + P, + ReduceOpT, + KVPReduceOpT, + decltype(core_lambda), + decltype(fin_op)>; dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NN); fusedL2NN<<>>( min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, core_lambda, fin_op); diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh old mode 100755 new mode 100644 index 31175d12bd..d0039b1627 --- a/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh @@ -39,7 +39,6 @@ #include "./fused_l2_nn_epilogue_elementwise.cuh" #include "./fused_l2_nn_gemm.h" - #define CUTLASS_CHECK(status) \ { \ cutlass::Status error = status; \ @@ -64,22 +63,22 @@ template void cutlassFusedL2NNKernel(const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - int* mutexes, - FinalLambda fin_op, - DistanceFn dist_op, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - cudaStream_t stream) + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + int* mutexes, + FinalLambda fin_op, + DistanceFn dist_op, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + cudaStream_t stream) { using EpilogueOutputOp = cutlass::epilogue::thread::FusedL2NNEpilogueElementwise::GemmKernel; + Alignment, + DataT, + Alignment, + AccT, + AccT, + EpilogueOutputOp, + NumStages, // Number of pipeline stages + isRowMajor>::GemmKernel; using cutlassDist = cutlass::gemm::device::GemmUniversalAdapter; @@ -131,7 +130,12 @@ void cutlassFusedL2NNKernel(const DataT* x, gemm_ldb = lda; typename cutlassDist::Arguments arguments{ - mode, problem_size, batch_count, epilog_op_param, a, b, + mode, + problem_size, + batch_count, + epilog_op_param, + a, + b, xn, // C matrix eq vector param, which here is A norm nullptr, // tensor_Z, (DataT*)yn, // this is broadcast vec, which is required to be non-const param @@ -140,14 +144,14 @@ void cutlassFusedL2NNKernel(const DataT* x, (int64_t)0, // batch stride B (int64_t)0, // batch stride Norm A (int64_t)0, - (int64_t)0, // batch stride Norm B - (int64_t)0, // batch stride Output - (int64_t)gemm_lda, // stride A - (int64_t)gemm_ldb, // stride B - 1, // stride A norm - 0, // this is no-op for Z - 0, // This must be zero - (int64_t)ldd // stride Output matrix + (int64_t)0, // batch stride Norm B + (int64_t)0, // batch stride Output + (int64_t)gemm_lda, // stride A + (int64_t)gemm_ldb, // stride B + 1, // stride A norm + 0, // this is no-op for Z + 0, // This must be zero + (int64_t)ldd // stride Output matrix }; // Using the arguments, query for extra workspace required for matrix multiplication computation @@ -167,8 +171,8 @@ void cutlassFusedL2NNKernel(const DataT* x, CUTLASS_CHECK(status); } -}; // namespace detail -}; // namespace distance -}; // namespace raft +}; // namespace detail +}; // namespace distance +}; // namespace raft #pragma GCC diagnostic pop diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh old mode 100755 new mode 100644 index d507d59c5d..2a23bc2733 --- a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -76,9 +76,11 @@ struct FusedL2NNEpilogue { // // Additional tensor tile iterator - stores t = Elementwise(z) // - using OutputTileIterator = - cutlass::epilogue::threadblock::PredicatedTileIteratorReducedVec; + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorReducedVec< + typename Base::OutputTileThreadMap, + ElementTensor, + LayoutT, + typename OutputOp::Params>; /// Define the epilogue using Epilogue = EpilogueWithBroadcast; using FragmentCompute = Array; @@ -89,16 +89,23 @@ class FusedL2NNEpilogueElementwise { DistanceOp_ dist_op_; KVPReduceOpT_ pair_redop_; ReduceOpT_ red_op_; - volatile int *mutexes_; + volatile int* mutexes_; // // Methods // CUTLASS_HOST_DEVICE - Params(DistanceOp_ dist_op, CGReduceOp cg_reduce_op, - ReduceOpT_ red_op, KVPReduceOpT_ pair_redop, - volatile int *mutexes) : - cg_reduce_op(cg_reduce_op), dist_op_(dist_op), pair_redop_(pair_redop), - red_op_(red_op), mutexes_(mutexes) {} + Params(DistanceOp_ dist_op, + CGReduceOp cg_reduce_op, + ReduceOpT_ red_op, + KVPReduceOpT_ pair_redop, + volatile int* mutexes) + : cg_reduce_op(cg_reduce_op), + dist_op_(dist_op), + pair_redop_(pair_redop), + red_op_(red_op), + mutexes_(mutexes) + { + } CUTLASS_HOST_DEVICE Params() {} @@ -120,8 +127,7 @@ class FusedL2NNEpilogueElementwise { /// Constructor from Params CUTLASS_HOST_DEVICE FusedL2NNEpilogueElementwise(Params const& params) - : elementwise_op(params.dist_op_), - pair_redop(params.pair_redop_), red_op(params.red_op_) + : elementwise_op(params.dist_op_), pair_redop(params.pair_redop_), red_op(params.red_op_) { } @@ -156,7 +162,6 @@ class FusedL2NNEpilogueElementwise { ElementCompute res_Z = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); red_op.init(&frag_T[i], res_Z); } - } /// Applies the operation when is_source_needed() is false diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h old mode 100755 new mode 100644 index aa20f348c2..6e44897880 --- a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h +++ b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,8 +23,8 @@ #include #include -#include "./fused_l2_nn_epilogue.cuh" #include "./fusedL2NN_gemm_with_fused_epilogue.h" +#include "./fused_l2_nn_epilogue.cuh" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -60,8 +60,7 @@ struct FusedL2NNGemm { /// Threadblock-level tile size (concept: GemmShape) // <- threadblock tile M = 32, N = 64, K = 16 - using ThreadblockShape = - cutlass::gemm::GemmShape<32, 64, 16>; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; /// Warp-level tile size (concept: GemmShape) // This code section describes tile size a warp will compute // <- warp tile M = 64, N = 64, K = 16 @@ -133,7 +132,8 @@ struct FusedL2NNGemm { GemmBase::Epilogue::kElementsPerAccess>::Epilogue; // Compose the GEMM kernel - using GemmKernel = FusedL2NNWithFusedEpilogue; + using GemmKernel = + FusedL2NNWithFusedEpilogue; }; template < @@ -152,15 +152,14 @@ template < /// data layout row/column major of inputs bool isRowMajor> struct FusedL2NNGemm { - + kAlignmentA, + double, + kAlignmentB, + ElementC_, + ElementAccumulator, + EpilogueOutputOp, + Stages, + isRowMajor> { // Threadblock-level tile size (concept: GemmShape) // <- threadblock tile M = 64, N = 64, K = 16 using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; @@ -233,7 +232,8 @@ struct FusedL2NNGemm::Epilogue; // Compose the GEMM kernel - using GemmKernel = FusedL2NNWithFusedEpilogue; + using GemmKernel = + FusedL2NNWithFusedEpilogue; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h old mode 100755 new mode 100644 index b6808f915d..ab428636be --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,6 +29,8 @@ This file contains a customized version of PredicatedTileIterator from CUTLASS 2 #pragma once +#include +#include #include #include #include @@ -41,8 +43,6 @@ This file contains a customized version of PredicatedTileIterator from CUTLASS 2 #include #include #include -#include -#include namespace cg = cooperative_groups; @@ -78,12 +78,12 @@ class PredicatedTileIteratorReducedVec { using TensorRef = TensorRef; using ConstTensorRef = typename TensorRef::ConstTensorRef; - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using TensorCoord = MatrixCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; using EpilogueOpParams = EpilogueOpParams_; - //static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + // static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; static int const kElementsPerAccess = 1; static int const kThreads = ThreadMap::kThreads; static int const kIterations = ThreadMap::Count::kTile; @@ -94,10 +94,10 @@ class PredicatedTileIteratorReducedVec { static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); /// Fragment object - using Fragment = Array; + using Fragment = + Array; /// Memory access size using AccessType = AlignedArray; @@ -124,10 +124,9 @@ class PredicatedTileIteratorReducedVec { CUTLASS_HOST_DEVICE Params(Layout const& layout, EpilogueOpParams const& user_param_) - : PredicatedTileIteratorParams( - int(sizeof(AccessType)) / kElementsPerAccess, - make_OutputTileThreadMapDesc()), - user_param(user_param_) + : PredicatedTileIteratorParams(int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()), + user_param(user_param_) { } @@ -219,11 +218,11 @@ class PredicatedTileIteratorReducedVec { /// Constructor CUTLASS_DEVICE PredicatedTileIteratorReducedVec(Params const& params, - Element* pointer, - TensorCoord extent, - int thread_idx, - TensorCoord threadblock_offset = TensorCoord(), - int const* indices = nullptr) + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) : params_(params), indices_(indices) { TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; @@ -329,10 +328,10 @@ class PredicatedTileIteratorReducedVec { CUTLASS_DEVICE void store_with_byte_offset(Fragment& frag, int64_t byte_offset) const { - uint8_t* byte_pointer = byte_pointer_; - AccessType* frag_ptr = reinterpret_cast(&frag); + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); - cg::thread_block cta = cg::this_thread_block(); + cg::thread_block cta = cg::this_thread_block(); cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta); CUTLASS_PRAGMA_UNROLL @@ -371,25 +370,25 @@ class PredicatedTileIteratorReducedVec { } } else { if (guard) { - const auto key_id = thread_start_column_ + ThreadMap::Delta::kColumn * column; + const auto key_id = thread_start_column_ + ThreadMap::Delta::kColumn * column; const int frag_col_idx = frag_idx + column; params_.user_param.red_op_.init_key((*frag_ptr)[frag_col_idx], key_id); - params_.user_param.red_op_(key_id, &(*frag_ptr)[frag_idx], (*frag_ptr)[frag_col_idx]); + params_.user_param.red_op_( + key_id, &(*frag_ptr)[frag_idx], (*frag_ptr)[frag_col_idx]); } } } bool col_guard = row_guard && mask_.predicates[0]; - auto subTile = cg::binary_partition(tile32, col_guard); + auto subTile = cg::binary_partition(tile32, col_guard); if (col_guard) { - (*frag_ptr)[frag_idx] = cg::reduce(subTile, (*frag_ptr)[frag_idx], params_.user_param.cg_reduce_op); + (*frag_ptr)[frag_idx] = + cg::reduce(subTile, (*frag_ptr)[frag_idx], params_.user_param.cg_reduce_op); } if (tile32.thread_rank() > 0) { if (row + 1 < ThreadMap::Iterations::kRow) { - if (!ScatterD) { - byte_pointer += params_.increment_row; - } + if (!ScatterD) { byte_pointer += params_.increment_row; } } } } @@ -431,32 +430,32 @@ class PredicatedTileIteratorReducedVec { // this performs better for most of the cases than per warp lock. if (threadIdx.x == 0 && thread_start_row_ < extent_row_) { // acquire mutex lock. - volatile int *row_mutex = params_.user_param.mutexes_ + thread_start_row_; - while (atomicCAS((int*)row_mutex, 0, 1) == 1); + volatile int* row_mutex = params_.user_param.mutexes_ + thread_start_row_; + while (atomicCAS((int*)row_mutex, 0, 1) == 1) + ; } __syncthreads(); if (tile32.thread_rank() == 0) { - CUTLASS_PRAGMA_UNROLL for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - const int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; + const int frag_row_idx = (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; bool row_guard = ((row_offset + thread_start_row_) < extent_row_); AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn; - if (row_guard && mask_.predicates[0]) { + const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn; + if (row_guard && mask_.predicates[0]) { // reduction with the current gmem value. params_.user_param.red_op_(row_offset + thread_start_row_, - (Element*)&memory_pointer[0], - (*frag_ptr)[frag_idx]); + (Element*)&memory_pointer[0], + (*frag_ptr)[frag_idx]); } if (row + 1 < ThreadMap::Iterations::kRow) { - if (!ScatterD) { - byte_pointer += params_.increment_row; - } + if (!ScatterD) { byte_pointer += params_.increment_row; } } } } @@ -465,13 +464,11 @@ class PredicatedTileIteratorReducedVec { __threadfence(); if (threadIdx.x == 0 && thread_start_row_ < extent_row_) { // release mutex lock. - volatile int *row_mutex = params_.user_param.mutexes_ + thread_start_row_; + volatile int* row_mutex = params_.user_param.mutexes_ + thread_start_row_; atomicCAS((int*)row_mutex, 1, 0); } #endif - if (group + 1 < ThreadMap::Iterations::kGroup) { - byte_pointer += params_.increment_group; - } + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } } if (cluster + 1 < ThreadMap::Iterations::kCluster) { @@ -651,7 +648,6 @@ class PredicatedTileIteratorReducedVec { (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; if (state_[1] == ThreadMap::Count::kGroup) { - state_[1] = 0; ++state_[2]; byte_pointer_ += params_.advance_cluster; diff --git a/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh b/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh index 8bc48332e6..fea7600723 100644 --- a/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh +++ b/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -60,7 +60,7 @@ struct FixConnectivitiesRedOp { value_idx m; // default constructor for cutlass - DI FixConnectivitiesRedOp() : colors(0), m(0) { } + DI FixConnectivitiesRedOp() : colors(0), m(0) {} FixConnectivitiesRedOp(value_idx* colors_, value_idx m_) : colors(colors_), m(m_){}; @@ -88,11 +88,8 @@ struct FixConnectivitiesRedOp { out->value = maxVal; } - DI void init_key(value_t &out, value_idx idx) const { return; } - DI void init_key(KVP &out, value_idx idx) const - { - out.key = idx; - } + DI void init_key(value_t& out, value_idx idx) const { return; } + DI void init_key(KVP& out, value_idx idx) const { out.key = idx; } }; /** From 207a96410e5f7be4b4e0d31fb45cb5fbeaee6e1f Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 4 Jan 2023 16:07:10 +0530 Subject: [PATCH 11/48] add comments to cutlass headers which are customized --- .../distance/detail/fusedL2NN_gemm_with_fused_epilogue.h | 6 ++++++ cpp/include/raft/distance/detail/fused_l2_nn_gemm.h | 1 - .../distance/detail/predicated_tile_iterator_reduced_vec.h | 5 +++-- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h b/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h index 9cec7b96b4..77aa7c8473 100644 --- a/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h +++ b/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h @@ -30,6 +30,12 @@ **************************************************************************************************/ /*! \file \brief Gemm kernel with fused reduction operation. + +This file contains a customized version of GemmWithFusedEpilogue from CUTLASS 2.9.0 +(https://github.com/NVIDIA/cutlass/blob/v2.9.0/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h) +* Changes: +-- added additional input parameter to params_Tensor constructor, + for passing user inputs to PredicatedTileIterator of reduced output values. */ #pragma once diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h index 6e44897880..279f2edd46 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h +++ b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h @@ -19,7 +19,6 @@ #include #include -//#include #include #include diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h index ab428636be..27a897413b 100644 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h @@ -22,8 +22,9 @@ This file contains a customized version of PredicatedTileIterator from CUTLASS 2 Changes: - added `Layout_` template param -- Only the row index is used to load the data in load_with_byte_offset(). - This way the same normalization data is used across all columns in a row. +- PredicatedTileIteratorParams() is customized to not stride by layout.stride(0). +- customized the store_with_byte_offset() to perform reduction per row and write final value to gmem. +- customized the Params() struct to take user inputs from epilogueOp params. */ From ff221546b1689ef46f6afe83edd08245b7f524d3 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 4 Jan 2023 20:49:41 +0530 Subject: [PATCH 12/48] fix clang format issues --- .../distance/detail/predicated_tile_iterator_reduced_vec.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h index 27a897413b..102cda91ee 100644 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h @@ -23,7 +23,8 @@ This file contains a customized version of PredicatedTileIterator from CUTLASS 2 Changes: - added `Layout_` template param - PredicatedTileIteratorParams() is customized to not stride by layout.stride(0). -- customized the store_with_byte_offset() to perform reduction per row and write final value to gmem. +- customized the store_with_byte_offset() to perform reduction per row and write final value to +gmem. - customized the Params() struct to take user inputs from epilogueOp params. */ From bf8f271a134abaea809249f22d933686d93c7e5c Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 3 Feb 2023 10:04:30 +0000 Subject: [PATCH 13/48] fix style checks --- .../detail/fusedL2NN_gemm_with_fused_epilogue.h | 16 ++++++++-------- .../distance/detail/fused_l2_nn_cutlass_base.cuh | 4 ++-- .../distance/detail/fused_l2_nn_epilogue.cuh | 4 ++-- .../raft/distance/detail/fused_l2_nn_gemm.h | 4 ++-- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h b/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h index 77aa7c8473..bc98ddaa11 100644 --- a/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h +++ b/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h @@ -40,14 +40,14 @@ This file contains a customized version of GemmWithFusedEpilogue from CUTLASS 2. #pragma once -#include "cutlass/complex.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" - -#include "cutlass/trace.h" +#include +#include +#include +#include +#include +#include + +#include ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh index d0039b1627..f59edc7182 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh @@ -36,8 +36,8 @@ #include #include -#include "./fused_l2_nn_epilogue_elementwise.cuh" -#include "./fused_l2_nn_gemm.h" +#include +#include #define CUTLASS_CHECK(status) \ { \ diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh index 2a23bc2733..f7fa0d0731 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh @@ -36,8 +36,8 @@ operation. #include -#include "./predicated_tile_iterator_normvec.h" -#include "./predicated_tile_iterator_reduced_vec.h" +#include +#include #include #include #include diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h index 279f2edd46..d1e73e3954 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h +++ b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h @@ -22,8 +22,8 @@ #include #include -#include "./fusedL2NN_gemm_with_fused_epilogue.h" -#include "./fused_l2_nn_epilogue.cuh" +#include +#include ///////////////////////////////////////////////////////////////////////////////////////////////// From c9421a94292ed63e6902d6573e0625b8b9ea91fa Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 3 Feb 2023 12:30:33 +0000 Subject: [PATCH 14/48] fix the style checks for header include --- cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh | 4 ++-- cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh | 4 ++-- cpp/include/raft/distance/detail/fused_l2_nn_gemm.h | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh index f59edc7182..6c87ac0a55 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh @@ -36,8 +36,8 @@ #include #include -#include -#include +#include +#include #define CUTLASS_CHECK(status) \ { \ diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh index f7fa0d0731..32bd117b91 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh @@ -36,8 +36,8 @@ operation. #include -#include -#include +#include +#include #include #include #include diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h index d1e73e3954..f59b9d0b4a 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h +++ b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h @@ -22,8 +22,8 @@ #include #include -#include -#include +#include +#include ///////////////////////////////////////////////////////////////////////////////////////////////// From 9215abe948e9cf46bda924db151e36e024073f80 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 3 Feb 2023 13:08:59 +0000 Subject: [PATCH 15/48] fix style check in fused_l2_knn --- cpp/test/neighbors/fused_l2_knn.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/test/neighbors/fused_l2_knn.cu b/cpp/test/neighbors/fused_l2_knn.cu index 981b3daa2b..349cb6167d 100644 --- a/cpp/test/neighbors/fused_l2_knn.cu +++ b/cpp/test/neighbors/fused_l2_knn.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "../test_utils.cuh" +#include #include #include From 08745d1a9efc8e4966631c6087f3f3fbae96be27 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 3 Feb 2023 13:20:01 +0000 Subject: [PATCH 16/48] fix style issues --- cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh | 5 +++-- cpp/test/neighbors/fused_l2_knn.cu | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh index 32bd117b91..8d94de3378 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh @@ -36,13 +36,14 @@ operation. #include -#include -#include #include #include #include #include +#include +#include + //////////////////////////////////////////////////////////////////////////////// namespace cutlass { diff --git a/cpp/test/neighbors/fused_l2_knn.cu b/cpp/test/neighbors/fused_l2_knn.cu index 349cb6167d..57ce28fd6e 100644 --- a/cpp/test/neighbors/fused_l2_knn.cu +++ b/cpp/test/neighbors/fused_l2_knn.cu @@ -17,11 +17,11 @@ #include #include +#include #include #include #include #include -#include #if defined RAFT_NN_COMPILED #include From e7470e88f6aef43e2904e03ebec8998486895d26 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 23 Feb 2023 08:37:27 +0000 Subject: [PATCH 17/48] use new gemm shape 32,128,16, make use of shared mem for reduction in tile iterator, this both improves perf by 20%+ compared to previous version by reducing gmem atomics+coalesced stores --- .../fusedL2NN_gemm_with_fused_epilogue.h | 15 +- .../raft/distance/detail/fused_l2_nn.cuh | 10 +- .../raft/distance/detail/fused_l2_nn_gemm.h | 6 +- .../predicated_tile_iterator_reduced_vec.h | 236 ++++++++++-------- 4 files changed, 154 insertions(+), 113 deletions(-) diff --git a/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h b/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h index bc98ddaa11..2fae243bf6 100644 --- a/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h +++ b/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h @@ -363,10 +363,15 @@ struct FusedL2NNWithFusedEpilogue { } }; + struct epilogue_SharedStorage { + typename Epilogue::SharedStorage epilogue; + typename Epilogue::TensorTileIterator::SharedStorage reduced_store; + }; + /// Shared memory storage structure union SharedStorage { typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; + epilogue_SharedStorage epilogue_combined_store; }; public: @@ -576,7 +581,8 @@ struct FusedL2NNWithFusedEpilogue { params.params_D, ptr_D, params.problem_size.mn(), thread_idx, threadblock_offset); // Additional tensor to load from - typename Epilogue::TensorTileIterator tensor_iterator(params.params_Tensor, + typename Epilogue::TensorTileIterator tensor_iterator(shared_storage.epilogue_combined_store.reduced_store, + params.params_Tensor, // Only the final block outputs Tensor ptr_Tensor, params.problem_size.mn(), @@ -584,7 +590,7 @@ struct FusedL2NNWithFusedEpilogue { threadblock_offset); // Construct the epilogue - Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + Epilogue epilogue(shared_storage.epilogue_combined_store.epilogue, thread_idx, warp_idx, lane_idx); // Move to appropriate location for this output tile if (ptr_Vector) { @@ -652,6 +658,7 @@ struct FusedL2NNWithFusedEpilogue { // Additional tensor to load from typename Epilogue::TensorTileIterator tensor_iterator( + shared_storage.epilogue_combined_store.reduced_store, params.params_Tensor, // Only the final block outputs Tensor ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && @@ -663,7 +670,7 @@ struct FusedL2NNWithFusedEpilogue { threadblock_offset); // Construct the epilogue - Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + Epilogue epilogue(shared_storage.epilogue_combined_store.epilogue, thread_idx, warp_idx, lane_idx); #if SPLIT_K_ENABLED // Wait on the semaphore - this latency may have been covered by iterator construction diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index b7ee95795f..1a725a023e 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -272,10 +272,10 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, // cg::reduce functor for FusedL2NN used in its cutlass version // to output the min distance value & key(loc id). template -struct kvp_cg_reduce_op { +struct kvp_cg_min_reduce_op { typedef typename raft::KeyValuePair KVP; - __host__ __device__ kvp_cg_reduce_op() noexcept {}; + __host__ __device__ kvp_cg_min_reduce_op() noexcept {}; // functor signature. __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } @@ -325,8 +325,8 @@ void fusedL2NNImpl(OutT* min, if (deviceVersion.first >= 8) { using L2Op = L2ExpandedOp; - using kvp_cg_reduce_op_ = kvp_cg_reduce_op; - kvp_cg_reduce_op_ cg_reduce_op; + using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; + kvp_cg_min_reduce_op_ cg_reduce_op; L2Op L2_dist_op(sqrt); IdxT lda, ldb, ldd; @@ -337,7 +337,7 @@ void fusedL2NNImpl(OutT* min, OutT, IdxT, P::Veclen, - kvp_cg_reduce_op_, + kvp_cg_min_reduce_op_, L2Op, ReduceOpT, KVPReduceOpT>(x, diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h index f59b9d0b4a..0ae85d19e1 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h +++ b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h @@ -59,11 +59,13 @@ struct FusedL2NNGemm { /// Threadblock-level tile size (concept: GemmShape) // <- threadblock tile M = 32, N = 64, K = 16 - using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; + using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; // this is more performant + /// Warp-level tile size (concept: GemmShape) // This code section describes tile size a warp will compute // <- warp tile M = 64, N = 64, K = 16 - using WarpShape = cutlass::gemm::GemmShape<16, 32, 16>; + using WarpShape = cutlass::gemm::GemmShape<16, 64, 16>; // this is more performant + /// Warp-level tile size (concept: GemmShape) // This code section describes the size of MMA op // <- MMA Op tile M = 16, N = 8, K = 4 diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h index 102cda91ee..1ec3828854 100644 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h @@ -23,6 +23,7 @@ This file contains a customized version of PredicatedTileIterator from CUTLASS 2 Changes: - added `Layout_` template param - PredicatedTileIteratorParams() is customized to not stride by layout.stride(0). +- makes use of `SharedStorage` to store reduced values across warps to gmem in coalesced manner. - customized the store_with_byte_offset() to perform reduction per row and write final value to gmem. - customized the Params() struct to take user inputs from epilogueOp params. @@ -94,6 +95,7 @@ class PredicatedTileIteratorReducedVec { static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); static_assert(ThreadMap::Iterations::kCluster > 0, "ThreadMap::Iterations::kCluster must be > 0"); static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); + static_assert(!UseCUDAStore, "UseCUDAStore path is not supported"); /// Fragment object using Fragment = @@ -168,6 +170,71 @@ class PredicatedTileIteratorReducedVec { } }; +/// Shared storage allocation needed by the predicated tile + // iterator for reduction. + struct SharedStorage { + // + // Type definitions + // + using Shape = MatrixShape< + ThreadMap::kWarpCount * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * + ThreadMap::Count::kTile, + 1 + >; + + /// Shape of the shared memory allocation for the reduced values store + using StorageShape = MatrixShape< + Shape::kRow, + Shape::kColumn + >; + + // + // Data members + // + static const int warp_row_stride = ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster; + static const int tile_row_stride = ThreadMap::kWarpCount * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster; + + // + // Methods + // + AlignedBuffer storage; + AlignedBuffer storage_gmem_ptr; + + CUTLASS_DEVICE + Element *data() { + return storage.data(); + } + + CUTLASS_DEVICE + Element *warp_data(int warp_id, int tile_id) { + + return data() + warp_id * warp_row_stride + tile_id * tile_row_stride; + } + + CUTLASS_DEVICE + Element **gmem_ptr_data() { + return storage_gmem_ptr.data(); + } + + + CUTLASS_DEVICE + Element **gmem_ptr_warp_data(int warp_id, int tile_id) { + + return gmem_ptr_data() + warp_id * warp_row_stride + tile_id * tile_row_stride; + } + + SharedStorage() { } + + }; + private: // // Data members @@ -190,12 +257,14 @@ class PredicatedTileIteratorReducedVec { /// A thread's starting row position (assuming steady-state predicates have been computed) Index thread_start_row_; + Index thread_start_row_first_tile_; /// A thread's starting column Index thread_start_column_; /// Internal state counter int state_[3]; + mutable int shared_tile_id; /// Scatter indices int const* indices_; @@ -208,6 +277,9 @@ class PredicatedTileIteratorReducedVec { static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); static_assert(sizeof(Params::stride) == 8, "Expected 64b strides"); +protected: + SharedStorage &shared_storage_; + private: // // Methods @@ -219,13 +291,14 @@ class PredicatedTileIteratorReducedVec { // /// Constructor CUTLASS_DEVICE - PredicatedTileIteratorReducedVec(Params const& params, + PredicatedTileIteratorReducedVec(SharedStorage &shared_storage, + Params const& params, Element* pointer, TensorCoord extent, int thread_idx, TensorCoord threadblock_offset = TensorCoord(), int const* indices = nullptr) - : params_(params), indices_(indices) + : params_(params), indices_(indices), shared_storage_(shared_storage) { TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; @@ -234,6 +307,8 @@ class PredicatedTileIteratorReducedVec { thread_start_row_ = thread_offset.row(); thread_start_column_ = thread_offset.column(); + thread_start_row_first_tile_ = thread_start_row_; + shared_tile_id = 0; // Initialize predicates CUTLASS_PRAGMA_UNROLL @@ -336,6 +411,9 @@ class PredicatedTileIteratorReducedVec { cg::thread_block cta = cg::this_thread_block(); cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta); + Element* warp_shared_elem_arr = shared_storage_.warp_data(tile32.meta_group_rank(), shared_tile_id); + Element** warp_gmem_ptrs = shared_storage_.gmem_ptr_warp_data(tile32.meta_group_rank(), shared_tile_id); + CUTLASS_PRAGMA_UNROLL for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { CUTLASS_PRAGMA_UNROLL @@ -352,33 +430,18 @@ class PredicatedTileIteratorReducedVec { AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - if (ScatterD && row_guard) { - assert(indices_); - - memory_pointer = reinterpret_cast( - byte_pointer + byte_offset + - LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); - } - const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn; CUTLASS_PRAGMA_UNROLL for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { bool guard = row_guard && mask_.predicates[column]; - if (UseCUDAStore) { - if (guard) { - memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] = - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; - } - } else { - if (guard) { - const auto key_id = thread_start_column_ + ThreadMap::Delta::kColumn * column; - const int frag_col_idx = frag_idx + column; - params_.user_param.red_op_.init_key((*frag_ptr)[frag_col_idx], key_id); - params_.user_param.red_op_( - key_id, &(*frag_ptr)[frag_idx], (*frag_ptr)[frag_col_idx]); - } + if (guard) { + const auto key_id = thread_start_column_ + ThreadMap::Delta::kColumn * column; + const int frag_col_idx = frag_idx + column; + params_.user_param.red_op_.init_key((*frag_ptr)[frag_col_idx], key_id); + params_.user_param.red_op_(key_id, &(*frag_ptr)[frag_idx], (*frag_ptr)[frag_col_idx]); } + } bool col_guard = row_guard && mask_.predicates[0]; auto subTile = cg::binary_partition(tile32, col_guard); @@ -386,90 +449,22 @@ class PredicatedTileIteratorReducedVec { if (col_guard) { (*frag_ptr)[frag_idx] = cg::reduce(subTile, (*frag_ptr)[frag_idx], params_.user_param.cg_reduce_op); - } - - if (tile32.thread_rank() > 0) { - if (row + 1 < ThreadMap::Iterations::kRow) { - if (!ScatterD) { byte_pointer += params_.increment_row; } - } - } - } -#if 0 - // single lock per warp for multiple rows - if (tile32.thread_rank() == 0 && thread_start_row_ < extent_row_) { - - volatile int *row_mutex = params_.user_param.mutexes_ + thread_start_row_; - while (atomicCAS((int*)row_mutex, 0, 1) == 1); - __threadfence(); - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - const int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - - AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn; - if (row_guard && mask_.predicates[0]) { - params_.user_param.red_op_(row_offset + thread_start_row_, - (Element*)&memory_pointer[0], - (*frag_ptr)[frag_idx]); - } - - if (row + 1 < ThreadMap::Iterations::kRow) { - if (!ScatterD) { - byte_pointer += params_.increment_row; + if (subTile.thread_rank() == 0) { + warp_shared_elem_arr[row] = (*frag_ptr)[frag_idx]; + warp_gmem_ptrs[row] = (Element*) &memory_pointer[0]; } - } } - __threadfence(); - atomicCAS((int*)row_mutex, 1, 0); - } -#else - // single lock per block for multiple rows - // this performs better for most of the cases than per warp lock. - if (threadIdx.x == 0 && thread_start_row_ < extent_row_) { - // acquire mutex lock. - volatile int* row_mutex = params_.user_param.mutexes_ + thread_start_row_; - while (atomicCAS((int*)row_mutex, 0, 1) == 1) - ; - } - __syncthreads(); - if (tile32.thread_rank() == 0) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - const int frag_row_idx = (row + ThreadMap::Iterations::kRow * - (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn; - if (row_guard && mask_.predicates[0]) { - // reduction with the current gmem value. - params_.user_param.red_op_(row_offset + thread_start_row_, - (Element*)&memory_pointer[0], - (*frag_ptr)[frag_idx]); - } - - if (row + 1 < ThreadMap::Iterations::kRow) { - if (!ScatterD) { byte_pointer += params_.increment_row; } - } + + if (!row_guard) { + if (tile32.thread_rank() == 0) { + warp_gmem_ptrs[row] = nullptr; + } + } + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { byte_pointer += params_.increment_row; } } } - __syncthreads(); - __threadfence(); - if (threadIdx.x == 0 && thread_start_row_ < extent_row_) { - // release mutex lock. - volatile int* row_mutex = params_.user_param.mutexes_ + thread_start_row_; - atomicCAS((int*)row_mutex, 1, 0); - } -#endif if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } } @@ -477,6 +472,42 @@ class PredicatedTileIteratorReducedVec { byte_pointer += params_.increment_cluster; } } + + // If this is last tile then perform reduction in gmem. + if (shared_tile_id == (ThreadMap::Count::kTile - 1)) { + // single lock per block for multiple rows + if (threadIdx.x == 0 && thread_start_row_first_tile_ < extent_row_) { + + // acquire mutex lock. + volatile int* row_mutex = params_.user_param.mutexes_ + thread_start_row_first_tile_; + while (atomicCAS((int*)row_mutex, 0, 1) == 1) + ; + } + __syncthreads(); + + auto shared_elem_arr = shared_storage_.data(); + auto shared_gmem_ptr = shared_storage_.gmem_ptr_data(); + + static int const num_of_vals = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * + ThreadMap::Count::kTile; + + for (int row = threadIdx.x; row < num_of_vals; row += blockDim.x) { + auto gmem_ptr = shared_gmem_ptr[row]; + if (gmem_ptr) { + params_.user_param.red_op_(0, gmem_ptr, shared_elem_arr[row]); + } + } + + __threadfence(); + __syncthreads(); + if (threadIdx.x == 0 && thread_start_row_first_tile_ < extent_row_) { + // release mutex lock. + volatile int* row_mutex = params_.user_param.mutexes_ + thread_start_row_first_tile_; + atomicCAS((int*)row_mutex, 1, 0); + } + } + } /// Stores a fragment to memory @@ -636,6 +667,7 @@ class PredicatedTileIteratorReducedVec { PredicatedTileIteratorReducedVec& operator++() { ++state_[0]; + shared_tile_id++; // tile iteration. if (!ScatterD) { byte_pointer_ += params_.advance_row; } From 20f7fb62561a75df656fb6a519ccc10dbca49a72 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 23 Feb 2023 12:40:21 +0000 Subject: [PATCH 18/48] fix formatting issues, move cutlass_check to a separate header following raft_cuda_try --- .../fusedL2NN_gemm_with_fused_epilogue.h | 23 ++-- .../raft/distance/detail/fused_l2_nn.cuh | 2 +- .../detail/fused_l2_nn_cutlass_base.cuh | 20 +--- .../fused_l2_nn_epilogue_elementwise.cuh | 5 +- .../raft/distance/detail/fused_l2_nn_gemm.h | 2 +- .../predicated_tile_iterator_reduced_vec.h | 102 +++++++----------- 6 files changed, 63 insertions(+), 91 deletions(-) diff --git a/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h b/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h index 2fae243bf6..b933ab9a7b 100644 --- a/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h +++ b/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h @@ -302,7 +302,7 @@ struct FusedL2NNWithFusedEpilogue { params_B(args.ldb), params_C(args.ldc), params_D(args.ldd), - // Here we additional pass user args via args.epilogue + // Here we pass additional user args via args.epilogue params_Tensor(args.ldt, args.epilogue), output_op(args.epilogue), mode(args.mode), @@ -581,16 +581,18 @@ struct FusedL2NNWithFusedEpilogue { params.params_D, ptr_D, params.problem_size.mn(), thread_idx, threadblock_offset); // Additional tensor to load from - typename Epilogue::TensorTileIterator tensor_iterator(shared_storage.epilogue_combined_store.reduced_store, - params.params_Tensor, - // Only the final block outputs Tensor - ptr_Tensor, - params.problem_size.mn(), - thread_idx, - threadblock_offset); + typename Epilogue::TensorTileIterator tensor_iterator( + shared_storage.epilogue_combined_store.reduced_store, + params.params_Tensor, + // Only the final block outputs Tensor + ptr_Tensor, + params.problem_size.mn(), + thread_idx, + threadblock_offset); // Construct the epilogue - Epilogue epilogue(shared_storage.epilogue_combined_store.epilogue, thread_idx, warp_idx, lane_idx); + Epilogue epilogue( + shared_storage.epilogue_combined_store.epilogue, thread_idx, warp_idx, lane_idx); // Move to appropriate location for this output tile if (ptr_Vector) { @@ -670,7 +672,8 @@ struct FusedL2NNWithFusedEpilogue { threadblock_offset); // Construct the epilogue - Epilogue epilogue(shared_storage.epilogue_combined_store.epilogue, thread_idx, warp_idx, lane_idx); + Epilogue epilogue( + shared_storage.epilogue_combined_store.epilogue, thread_idx, warp_idx, lane_idx); #if SPLIT_K_ENABLED // Wait on the semaphore - this latency may have been covered by iterator construction diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 1a725a023e..97e882f62a 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -324,7 +324,7 @@ void fusedL2NNImpl(OutT* min, const auto deviceVersion = getComputeCapability(); if (deviceVersion.first >= 8) { - using L2Op = L2ExpandedOp; + using L2Op = L2ExpandedOp; using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; kvp_cg_min_reduce_op_ cg_reduce_op; L2Op L2_dist_op(sqrt); diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh index 6c87ac0a55..bd910c0240 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh @@ -38,16 +38,7 @@ #include #include - -#define CUTLASS_CHECK(status) \ - { \ - cutlass::Status error = status; \ - if (error != cutlass::Status::kSuccess) { \ - std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ - << std::endl; \ - exit(EXIT_FAILURE); \ - } \ - } +#include namespace raft { namespace distance { @@ -161,14 +152,11 @@ void cutlassFusedL2NNKernel(const DataT* x, // Instantiate CUTLASS kernel depending on templates cutlassDist cutlassDist_op; // Check the problem size is supported or not - cutlass::Status status = cutlassDist_op.can_implement(arguments); - CUTLASS_CHECK(status); + RAFT_CUTLASS_TRY(cutlassDist_op.can_implement(arguments)); // Initialize CUTLASS kernel with arguments and workspace pointer - status = cutlassDist_op.initialize(arguments, workspace.data(), stream); - CUTLASS_CHECK(status); + RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream)); // Launch initialized CUTLASS kernel - status = cutlassDist_op(); - CUTLASS_CHECK(status); + RAFT_CUTLASS_TRY(cutlassDist_op()); } }; // namespace detail diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh index 216d1eec04..44852b32c0 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh @@ -20,7 +20,8 @@ * kernels. * This is adapted from LinearCombinationBiasElementwise from CUTLASS 2.9.0 * customized for applying elementwise distance formula on accumulated GEMM value -* and applying user-defined final custom operation on the distance value. +* and applying user-defined operation which can convert distance values to key-value pair. +* . */ #pragma once @@ -135,7 +136,7 @@ class FusedL2NNEpilogueElementwise { CUTLASS_HOST_DEVICE bool is_source_needed() const { - // we use for making sure C matrix path is used for A mat norm. + // we use for making sure C matrix is used for A mat norm. return true; } diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h index 0ae85d19e1..c2d8a7c507 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h +++ b/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h @@ -59,7 +59,7 @@ struct FusedL2NNGemm { /// Threadblock-level tile size (concept: GemmShape) // <- threadblock tile M = 32, N = 64, K = 16 - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; // this is more performant + using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; // this is more performant /// Warp-level tile size (concept: GemmShape) // This code section describes tile size a warp will compute diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h index 1ec3828854..7f87bc3a24 100644 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h @@ -170,37 +170,28 @@ class PredicatedTileIteratorReducedVec { } }; -/// Shared storage allocation needed by the predicated tile + /// Shared storage allocation needed by the predicated tile // iterator for reduction. struct SharedStorage { // // Type definitions // - using Shape = MatrixShape< - ThreadMap::kWarpCount * - ThreadMap::Iterations::kRow * - ThreadMap::Iterations::kGroup * - ThreadMap::Iterations::kCluster * - ThreadMap::Count::kTile, - 1 - >; + using Shape = MatrixShape; /// Shape of the shared memory allocation for the reduced values store - using StorageShape = MatrixShape< - Shape::kRow, - Shape::kColumn - >; + using StorageShape = MatrixShape; // // Data members // - static const int warp_row_stride = ThreadMap::Iterations::kRow * - ThreadMap::Iterations::kGroup * - ThreadMap::Iterations::kCluster; - static const int tile_row_stride = ThreadMap::kWarpCount * - ThreadMap::Iterations::kRow * - ThreadMap::Iterations::kGroup * - ThreadMap::Iterations::kCluster; + static const int warp_row_stride = + ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster; + static const int tile_row_stride = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster; // // Methods @@ -209,30 +200,24 @@ class PredicatedTileIteratorReducedVec { AlignedBuffer storage_gmem_ptr; CUTLASS_DEVICE - Element *data() { - return storage.data(); - } + Element* data() { return storage.data(); } CUTLASS_DEVICE - Element *warp_data(int warp_id, int tile_id) { - + Element* warp_data(int warp_id, int tile_id) + { return data() + warp_id * warp_row_stride + tile_id * tile_row_stride; } CUTLASS_DEVICE - Element **gmem_ptr_data() { - return storage_gmem_ptr.data(); - } - + Element** gmem_ptr_data() { return storage_gmem_ptr.data(); } CUTLASS_DEVICE - Element **gmem_ptr_warp_data(int warp_id, int tile_id) { - + Element** gmem_ptr_warp_data(int warp_id, int tile_id) + { return gmem_ptr_data() + warp_id * warp_row_stride + tile_id * tile_row_stride; } - SharedStorage() { } - + SharedStorage() {} }; private: @@ -277,9 +262,9 @@ class PredicatedTileIteratorReducedVec { static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); static_assert(sizeof(Params::stride) == 8, "Expected 64b strides"); -protected: - SharedStorage &shared_storage_; - + protected: + SharedStorage& shared_storage_; + private: // // Methods @@ -291,7 +276,7 @@ class PredicatedTileIteratorReducedVec { // /// Constructor CUTLASS_DEVICE - PredicatedTileIteratorReducedVec(SharedStorage &shared_storage, + PredicatedTileIteratorReducedVec(SharedStorage& shared_storage, Params const& params, Element* pointer, TensorCoord extent, @@ -305,10 +290,10 @@ class PredicatedTileIteratorReducedVec { extent_row_ = extent.row(); extent_column_ = extent.column(); - thread_start_row_ = thread_offset.row(); - thread_start_column_ = thread_offset.column(); + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); thread_start_row_first_tile_ = thread_start_row_; - shared_tile_id = 0; + shared_tile_id = 0; // Initialize predicates CUTLASS_PRAGMA_UNROLL @@ -411,8 +396,10 @@ class PredicatedTileIteratorReducedVec { cg::thread_block cta = cg::this_thread_block(); cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta); - Element* warp_shared_elem_arr = shared_storage_.warp_data(tile32.meta_group_rank(), shared_tile_id); - Element** warp_gmem_ptrs = shared_storage_.gmem_ptr_warp_data(tile32.meta_group_rank(), shared_tile_id); + Element* warp_shared_elem_arr = + shared_storage_.warp_data(tile32.meta_group_rank(), shared_tile_id); + Element** warp_gmem_ptrs = + shared_storage_.gmem_ptr_warp_data(tile32.meta_group_rank(), shared_tile_id); CUTLASS_PRAGMA_UNROLL for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { @@ -441,7 +428,6 @@ class PredicatedTileIteratorReducedVec { params_.user_param.red_op_.init_key((*frag_ptr)[frag_col_idx], key_id); params_.user_param.red_op_(key_id, &(*frag_ptr)[frag_idx], (*frag_ptr)[frag_col_idx]); } - } bool col_guard = row_guard && mask_.predicates[0]; auto subTile = cg::binary_partition(tile32, col_guard); @@ -449,16 +435,14 @@ class PredicatedTileIteratorReducedVec { if (col_guard) { (*frag_ptr)[frag_idx] = cg::reduce(subTile, (*frag_ptr)[frag_idx], params_.user_param.cg_reduce_op); - if (subTile.thread_rank() == 0) { - warp_shared_elem_arr[row] = (*frag_ptr)[frag_idx]; - warp_gmem_ptrs[row] = (Element*) &memory_pointer[0]; - } + if (subTile.thread_rank() == 0) { + warp_shared_elem_arr[row] = (*frag_ptr)[frag_idx]; + warp_gmem_ptrs[row] = (Element*)&memory_pointer[0]; + } } - + if (!row_guard) { - if (tile32.thread_rank() == 0) { - warp_gmem_ptrs[row] = nullptr; - } + if (tile32.thread_rank() == 0) { warp_gmem_ptrs[row] = nullptr; } } if (row + 1 < ThreadMap::Iterations::kRow) { if (!ScatterD) { byte_pointer += params_.increment_row; } @@ -477,26 +461,23 @@ class PredicatedTileIteratorReducedVec { if (shared_tile_id == (ThreadMap::Count::kTile - 1)) { // single lock per block for multiple rows if (threadIdx.x == 0 && thread_start_row_first_tile_ < extent_row_) { - // acquire mutex lock. volatile int* row_mutex = params_.user_param.mutexes_ + thread_start_row_first_tile_; while (atomicCAS((int*)row_mutex, 0, 1) == 1) ; } __syncthreads(); - + auto shared_elem_arr = shared_storage_.data(); auto shared_gmem_ptr = shared_storage_.gmem_ptr_data(); - static int const num_of_vals = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * - ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * - ThreadMap::Count::kTile; + static int const num_of_vals = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * ThreadMap::Count::kTile; for (int row = threadIdx.x; row < num_of_vals; row += blockDim.x) { auto gmem_ptr = shared_gmem_ptr[row]; - if (gmem_ptr) { - params_.user_param.red_op_(0, gmem_ptr, shared_elem_arr[row]); - } + if (gmem_ptr) { params_.user_param.red_op_(0, gmem_ptr, shared_elem_arr[row]); } } __threadfence(); @@ -507,7 +488,6 @@ class PredicatedTileIteratorReducedVec { atomicCAS((int*)row_mutex, 1, 0); } } - } /// Stores a fragment to memory @@ -667,7 +647,7 @@ class PredicatedTileIteratorReducedVec { PredicatedTileIteratorReducedVec& operator++() { ++state_[0]; - shared_tile_id++; // tile iteration. + shared_tile_id++; // tile iteration. if (!ScatterD) { byte_pointer_ += params_.advance_row; } From 4941064b140ac759b5f48bed16828d2596db52cb Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 23 Feb 2023 13:27:31 +0000 Subject: [PATCH 19/48] add missed out cutlass_utils.cuh --- cpp/include/raft/util/cutlass_utils.cuh | 54 +++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 cpp/include/raft/util/cutlass_utils.cuh diff --git a/cpp/include/raft/util/cutlass_utils.cuh b/cpp/include/raft/util/cutlass_utils.cuh new file mode 100644 index 0000000000..3456c0c3e5 --- /dev/null +++ b/cpp/include/raft/util/cutlass_utils.cuh @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft { + +/** + * @brief Exception thrown when a CUDA error is encountered. + */ +struct cutlass_error : public raft::exception { + explicit cutlass_error(char const* const message) : raft::exception(message) {} + explicit cutlass_error(std::string const& message) : raft::exception(message) {} +}; + +} // namespace raft + +/** + * @brief Error checking macro for CUDA runtime API functions. + * + * Invokes a CUDA runtime API function call, if the call does not return + * cudaSuccess, invokes cudaGetLastError() to clear the error and throws an + * exception detailing the CUDA error that occurred + * + */ +#define RAFT_CUTLASS_TRY(call) \ + do { \ + cutlass::Status const status = call; \ + if (status != cutlass::Status::kSuccess) { \ + std::string msg{}; \ + SET_ERROR_MSG(msg, \ + "CUTLASS error encountered at: ", \ + "call='%s', Reason=%s", \ + #call, \ + cutlassGetStatusString(status)); \ + throw raft::cutlass_error(msg); \ + } \ + } while (0) From 1b062d5c337b52a14e90c6017302669186431efc Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 9 Mar 2023 08:43:26 +0000 Subject: [PATCH 20/48] remove usage of shared mem gmem ptr storage, use block offset to store init ptr in regs, and only use consecutive mutexes instead of block strided mutex. overall improves perf slightly --- .../fused_l2_nn_epilogue_elementwise.cuh | 4 +- .../predicated_tile_iterator_reduced_vec.h | 96 +++++++------------ 2 files changed, 36 insertions(+), 64 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh index 44852b32c0..1d90e4f41d 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh @@ -90,7 +90,7 @@ class FusedL2NNEpilogueElementwise { DistanceOp_ dist_op_; KVPReduceOpT_ pair_redop_; ReduceOpT_ red_op_; - volatile int* mutexes_; + int* mutexes_; // // Methods // @@ -99,7 +99,7 @@ class FusedL2NNEpilogueElementwise { CGReduceOp cg_reduce_op, ReduceOpT_ red_op, KVPReduceOpT_ pair_redop, - volatile int* mutexes) + int* mutexes) : cg_reduce_op(cg_reduce_op), dist_op_(dist_op), pair_redop_(pair_redop), diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h index 7f87bc3a24..d65df781b7 100644 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h @@ -197,26 +197,10 @@ class PredicatedTileIteratorReducedVec { // Methods // AlignedBuffer storage; - AlignedBuffer storage_gmem_ptr; CUTLASS_DEVICE Element* data() { return storage.data(); } - CUTLASS_DEVICE - Element* warp_data(int warp_id, int tile_id) - { - return data() + warp_id * warp_row_stride + tile_id * tile_row_stride; - } - - CUTLASS_DEVICE - Element** gmem_ptr_data() { return storage_gmem_ptr.data(); } - - CUTLASS_DEVICE - Element** gmem_ptr_warp_data(int warp_id, int tile_id) - { - return gmem_ptr_data() + warp_id * warp_row_stride + tile_id * tile_row_stride; - } - SharedStorage() {} }; @@ -230,6 +214,8 @@ class PredicatedTileIteratorReducedVec { /// Byte-level pointer uint8_t* byte_pointer_; + /// Byte-level pointer first tile offset of this threadblock. + uint8_t* first_tile_byte_pointer_; /// Array of boolean values to contain steady-state predicates Mask mask_; @@ -242,7 +228,7 @@ class PredicatedTileIteratorReducedVec { /// A thread's starting row position (assuming steady-state predicates have been computed) Index thread_start_row_; - Index thread_start_row_first_tile_; + Index block_start_row_first_tile_; /// A thread's starting column Index thread_start_column_; @@ -290,10 +276,12 @@ class PredicatedTileIteratorReducedVec { extent_row_ = extent.row(); extent_column_ = extent.column(); - thread_start_row_ = thread_offset.row(); - thread_start_column_ = thread_offset.column(); - thread_start_row_first_tile_ = thread_start_row_; - shared_tile_id = 0; + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + TensorCoord block_offset = ThreadMap::initial_offset(0) + threadblock_offset; + block_start_row_first_tile_ = block_offset.row(); + shared_tile_id = 0; // Initialize predicates CUTLASS_PRAGMA_UNROLL @@ -311,6 +299,9 @@ class PredicatedTileIteratorReducedVec { byte_pointer_ = reinterpret_cast(pointer) + LongIndex(thread_offset.row()) * LongIndex(params_.stride); + first_tile_byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(block_offset.row()) * LongIndex(params_.stride); + if (ScatterD) { byte_pointer_ = reinterpret_cast(pointer) + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; @@ -390,16 +381,17 @@ class PredicatedTileIteratorReducedVec { CUTLASS_DEVICE void store_with_byte_offset(Fragment& frag, int64_t byte_offset) const { - uint8_t* byte_pointer = byte_pointer_; - AccessType* frag_ptr = reinterpret_cast(&frag); + AccessType* frag_ptr = reinterpret_cast(&frag); cg::thread_block cta = cg::this_thread_block(); cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta); - Element* warp_shared_elem_arr = - shared_storage_.warp_data(tile32.meta_group_rank(), shared_tile_id); - Element** warp_gmem_ptrs = - shared_storage_.gmem_ptr_warp_data(tile32.meta_group_rank(), shared_tile_id); + Element* shared_elem_arr = shared_storage_.data(); + EpilogueOpParams const& user_params = params_.user_param; + + static int const total_rows = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * + ThreadMap::Count::kTile; CUTLASS_PRAGMA_UNROLL for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { @@ -415,8 +407,6 @@ class PredicatedTileIteratorReducedVec { bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn; CUTLASS_PRAGMA_UNROLL for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { @@ -425,8 +415,8 @@ class PredicatedTileIteratorReducedVec { if (guard) { const auto key_id = thread_start_column_ + ThreadMap::Delta::kColumn * column; const int frag_col_idx = frag_idx + column; - params_.user_param.red_op_.init_key((*frag_ptr)[frag_col_idx], key_id); - params_.user_param.red_op_(key_id, &(*frag_ptr)[frag_idx], (*frag_ptr)[frag_col_idx]); + user_params.red_op_.init_key((*frag_ptr)[frag_col_idx], key_id); + user_params.red_op_(key_id, &(*frag_ptr)[frag_idx], (*frag_ptr)[frag_col_idx]); } } bool col_guard = row_guard && mask_.predicates[0]; @@ -434,58 +424,40 @@ class PredicatedTileIteratorReducedVec { if (col_guard) { (*frag_ptr)[frag_idx] = - cg::reduce(subTile, (*frag_ptr)[frag_idx], params_.user_param.cg_reduce_op); + cg::reduce(subTile, (*frag_ptr)[frag_idx], user_params.cg_reduce_op); if (subTile.thread_rank() == 0) { - warp_shared_elem_arr[row] = (*frag_ptr)[frag_idx]; - warp_gmem_ptrs[row] = (Element*)&memory_pointer[0]; + int iter_row = ((row_offset + thread_start_row_) % total_rows); + shared_elem_arr[iter_row] = (*frag_ptr)[frag_idx]; } } - - if (!row_guard) { - if (tile32.thread_rank() == 0) { warp_gmem_ptrs[row] = nullptr; } - } - if (row + 1 < ThreadMap::Iterations::kRow) { - if (!ScatterD) { byte_pointer += params_.increment_row; } - } } - - if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; } } // If this is last tile then perform reduction in gmem. if (shared_tile_id == (ThreadMap::Count::kTile - 1)) { + const auto mutex_id = (block_start_row_first_tile_ / total_rows); // single lock per block for multiple rows - if (threadIdx.x == 0 && thread_start_row_first_tile_ < extent_row_) { + if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { // acquire mutex lock. - volatile int* row_mutex = params_.user_param.mutexes_ + thread_start_row_first_tile_; - while (atomicCAS((int*)row_mutex, 0, 1) == 1) + while (atomicCAS(user_params.mutexes_ + mutex_id, 0, 1) == 1) ; } __syncthreads(); - auto shared_elem_arr = shared_storage_.data(); - auto shared_gmem_ptr = shared_storage_.gmem_ptr_data(); - - static int const num_of_vals = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * - ThreadMap::Iterations::kGroup * - ThreadMap::Iterations::kCluster * ThreadMap::Count::kTile; + auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - for (int row = threadIdx.x; row < num_of_vals; row += blockDim.x) { - auto gmem_ptr = shared_gmem_ptr[row]; - if (gmem_ptr) { params_.user_param.red_op_(0, gmem_ptr, shared_elem_arr[row]); } + for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { + if (block_start_row_first_tile_ + row < extent_row_) { + user_params.red_op_(0, gmem_ptr + row, shared_elem_arr[row]); + } } __threadfence(); __syncthreads(); - if (threadIdx.x == 0 && thread_start_row_first_tile_ < extent_row_) { + if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { // release mutex lock. - volatile int* row_mutex = params_.user_param.mutexes_ + thread_start_row_first_tile_; - atomicCAS((int*)row_mutex, 1, 0); + atomicCAS(user_params.mutexes_ + mutex_id, 1, 0); } } } From 8a42b7784e6c615ff6479eeaae39aead8407551d Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 16 Mar 2023 08:09:29 -0700 Subject: [PATCH 21/48] use tile32 whenever group size is 32 as it uses optimal reduce, eliminate shfl of keys by using ballot based write by thread containing min val --- cpp/cmake/thirdparty/get_cutlass.cmake | 4 +- .../raft/distance/detail/fused_l2_nn.cuh | 2 + .../fused_l2_nn_epilogue_elementwise.cuh | 1 + .../predicated_tile_iterator_reduced_vec.h | 86 ++++++++++++++++--- 4 files changed, 81 insertions(+), 12 deletions(-) diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake index 3e02ce064e..c781563a2b 100644 --- a/cpp/cmake/thirdparty/get_cutlass.cmake +++ b/cpp/cmake/thirdparty/get_cutlass.cmake @@ -91,7 +91,7 @@ function(find_and_configure_cutlass) endfunction() if(NOT RAFT_CUTLASS_GIT_TAG) - set(RAFT_CUTLASS_GIT_TAG v2.9.1) + set(RAFT_CUTLASS_GIT_TAG v2.10.0) endif() if(NOT RAFT_CUTLASS_GIT_REPOSITORY) @@ -99,5 +99,5 @@ if(NOT RAFT_CUTLASS_GIT_REPOSITORY) endif() find_and_configure_cutlass( - VERSION 2.9.1 REPOSITORY ${RAFT_CUTLASS_GIT_REPOSITORY} PINNED_TAG ${RAFT_CUTLASS_GIT_TAG} + VERSION 2.10.0 REPOSITORY ${RAFT_CUTLASS_GIT_REPOSITORY} PINNED_TAG ${RAFT_CUTLASS_GIT_TAG} ) diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 7771e8d29d..6ff2e2698a 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -278,6 +278,8 @@ struct kvp_cg_min_reduce_op { __host__ __device__ kvp_cg_min_reduce_op() noexcept {}; + using AccTypeT = AccType; + using IndexT = Index; // functor signature. __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } __host__ __device__ AccType operator()(AccType a, AccType b) const { return a < b ? a : b; } diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh index 1d90e4f41d..b9ab2c0170 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh @@ -91,6 +91,7 @@ class FusedL2NNEpilogueElementwise { KVPReduceOpT_ pair_redop_; ReduceOpT_ red_op_; int* mutexes_; + using CGReduceT = CGReduceOp_; // // Methods // diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h index d65df781b7..fc399932d2 100644 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h @@ -85,6 +85,8 @@ class PredicatedTileIteratorReducedVec { using LongIndex = typename Layout::LongIndex; using TensorCoord = MatrixCoord; using EpilogueOpParams = EpilogueOpParams_; + using OutIdxT = typename EpilogueOpParams::CGReduceT::IndexT; + using OutValT = typename EpilogueOpParams::CGReduceT::AccTypeT; // static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; static int const kElementsPerAccess = 1; @@ -204,6 +206,62 @@ class PredicatedTileIteratorReducedVec { SharedStorage() {} }; + template + struct select_reduce { + /// Performs reduction and stores a reduced output to memory + CUTLASS_DEVICE + select_reduce(OutT red_value, reduce_op_t reduce_op, + cg_group_t cg_warp_group, OutT& shmem_ptr) + { + OutT reduced_val = cg::reduce(cg_warp_group, red_value, reduce_op); + if (cg_warp_group.thread_rank() == 0) { + shmem_ptr = reduced_val; + } + } + }; + + template + struct select_reduce > { + using ValT = float; + using Ty = raft::KeyValuePair; + + CUTLASS_DEVICE + select_reduce(Ty val_to_red, reduce_op_t reduce_op, + cg_group_t cg_warp_group, Ty & shmem_ptr) + { + ValT val = val_to_red.value; + ValT reduced_val = cg::reduce(cg_warp_group, val, reduce_op); + bool pred = (reduced_val == val); + auto subTile = cg::binary_partition(cg_warp_group, pred); + if (pred) { + if (subTile.thread_rank() == 0) { + shmem_ptr = val_to_red; + } + } + } + }; + + template + struct select_reduce > { + using ValT = double; + using Ty = raft::KeyValuePair; + + CUTLASS_DEVICE + select_reduce(Ty val_to_red, reduce_op_t reduce_op, + cg_group_t cg_warp_group, Ty & shmem_ptr) + { + ValT val = val_to_red.value; + ValT reduced_val = cg::reduce(cg_warp_group, val, reduce_op); + bool pred = (reduced_val == val); + auto subTile = cg::binary_partition(cg_warp_group, pred); + if (pred) { + if (subTile.thread_rank() == 0) { + shmem_ptr = val_to_red; + } + } + } + }; + private: // // Data members @@ -385,10 +443,13 @@ class PredicatedTileIteratorReducedVec { cg::thread_block cta = cg::this_thread_block(); cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta); - - Element* shared_elem_arr = shared_storage_.data(); EpilogueOpParams const& user_params = params_.user_param; + using cg_reduce_t = decltype(user_params.cg_reduce_op); + using tile32_t = decltype(tile32); + + Element* shared_elem_arr = shared_storage_.data(); + static int const total_rows = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * ThreadMap::Count::kTile; @@ -419,15 +480,20 @@ class PredicatedTileIteratorReducedVec { user_params.red_op_(key_id, &(*frag_ptr)[frag_idx], (*frag_ptr)[frag_col_idx]); } } - bool col_guard = row_guard && mask_.predicates[0]; - auto subTile = cg::binary_partition(tile32, col_guard); + bool col_guard = row_guard && mask_.predicates[0]; + auto subTile = cg::binary_partition(tile32, col_guard); + using subTile_t = decltype(subTile); if (col_guard) { - (*frag_ptr)[frag_idx] = - cg::reduce(subTile, (*frag_ptr)[frag_idx], user_params.cg_reduce_op); - if (subTile.thread_rank() == 0) { - int iter_row = ((row_offset + thread_start_row_) % total_rows); - shared_elem_arr[iter_row] = (*frag_ptr)[frag_idx]; + int iter_row = ((row_offset + thread_start_row_) % total_rows); + if (subTile.size() == 32) { + select_reduce + red_obj((*frag_ptr)[frag_idx], user_params.cg_reduce_op, + tile32, shared_elem_arr[iter_row]); + } else { + select_reduce + red_obj((*frag_ptr)[frag_idx], user_params.cg_reduce_op, + subTile, shared_elem_arr[iter_row]); } } } @@ -449,7 +515,7 @@ class PredicatedTileIteratorReducedVec { for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { if (block_start_row_first_tile_ + row < extent_row_) { - user_params.red_op_(0, gmem_ptr + row, shared_elem_arr[row]); + user_params.red_op_(0, &gmem_ptr[row], shared_elem_arr[row]); } } From 96e6e1ec610a3d78bf638798e357643698d566b3 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 16 Mar 2023 09:14:12 -0700 Subject: [PATCH 22/48] fix formatting issue --- .../raft/distance/detail/fused_l2_nn.cuh | 2 +- .../predicated_tile_iterator_reduced_vec.h | 68 ++++++++----------- 2 files changed, 31 insertions(+), 39 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 6ff2e2698a..cea52bb863 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -279,7 +279,7 @@ struct kvp_cg_min_reduce_op { __host__ __device__ kvp_cg_min_reduce_op() noexcept {}; using AccTypeT = AccType; - using IndexT = Index; + using IndexT = Index; // functor signature. __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } __host__ __device__ AccType operator()(AccType a, AccType b) const { return a < b ? a : b; } diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h index fc399932d2..e8654c9d42 100644 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h @@ -210,54 +210,45 @@ class PredicatedTileIteratorReducedVec { struct select_reduce { /// Performs reduction and stores a reduced output to memory CUTLASS_DEVICE - select_reduce(OutT red_value, reduce_op_t reduce_op, - cg_group_t cg_warp_group, OutT& shmem_ptr) + select_reduce(OutT red_value, reduce_op_t reduce_op, cg_group_t cg_warp_group, OutT& shmem_ptr) { OutT reduced_val = cg::reduce(cg_warp_group, red_value, reduce_op); - if (cg_warp_group.thread_rank() == 0) { - shmem_ptr = reduced_val; - } + if (cg_warp_group.thread_rank() == 0) { shmem_ptr = reduced_val; } } }; template - struct select_reduce > { + struct select_reduce> { using ValT = float; - using Ty = raft::KeyValuePair; + using Ty = raft::KeyValuePair; CUTLASS_DEVICE - select_reduce(Ty val_to_red, reduce_op_t reduce_op, - cg_group_t cg_warp_group, Ty & shmem_ptr) + select_reduce(Ty val_to_red, reduce_op_t reduce_op, cg_group_t cg_warp_group, Ty& shmem_ptr) { - ValT val = val_to_red.value; - ValT reduced_val = cg::reduce(cg_warp_group, val, reduce_op); - bool pred = (reduced_val == val); - auto subTile = cg::binary_partition(cg_warp_group, pred); + ValT val = val_to_red.value; + ValT reduced_val = cg::reduce(cg_warp_group, val, reduce_op); + bool pred = (reduced_val == val); + auto subTile = cg::binary_partition(cg_warp_group, pred); if (pred) { - if (subTile.thread_rank() == 0) { - shmem_ptr = val_to_red; - } + if (subTile.thread_rank() == 0) { shmem_ptr = val_to_red; } } } }; template - struct select_reduce > { + struct select_reduce> { using ValT = double; - using Ty = raft::KeyValuePair; + using Ty = raft::KeyValuePair; CUTLASS_DEVICE - select_reduce(Ty val_to_red, reduce_op_t reduce_op, - cg_group_t cg_warp_group, Ty & shmem_ptr) + select_reduce(Ty val_to_red, reduce_op_t reduce_op, cg_group_t cg_warp_group, Ty& shmem_ptr) { - ValT val = val_to_red.value; - ValT reduced_val = cg::reduce(cg_warp_group, val, reduce_op); - bool pred = (reduced_val == val); - auto subTile = cg::binary_partition(cg_warp_group, pred); + ValT val = val_to_red.value; + ValT reduced_val = cg::reduce(cg_warp_group, val, reduce_op); + bool pred = (reduced_val == val); + auto subTile = cg::binary_partition(cg_warp_group, pred); if (pred) { - if (subTile.thread_rank() == 0) { - shmem_ptr = val_to_red; - } + if (subTile.thread_rank() == 0) { shmem_ptr = val_to_red; } } } }; @@ -441,8 +432,8 @@ class PredicatedTileIteratorReducedVec { { AccessType* frag_ptr = reinterpret_cast(&frag); - cg::thread_block cta = cg::this_thread_block(); - cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta); + cg::thread_block cta = cg::this_thread_block(); + cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta); EpilogueOpParams const& user_params = params_.user_param; using cg_reduce_t = decltype(user_params.cg_reduce_op); @@ -480,20 +471,21 @@ class PredicatedTileIteratorReducedVec { user_params.red_op_(key_id, &(*frag_ptr)[frag_idx], (*frag_ptr)[frag_col_idx]); } } - bool col_guard = row_guard && mask_.predicates[0]; - auto subTile = cg::binary_partition(tile32, col_guard); - using subTile_t = decltype(subTile); + bool col_guard = row_guard && mask_.predicates[0]; + auto subTile = cg::binary_partition(tile32, col_guard); + using subTile_t = decltype(subTile); if (col_guard) { int iter_row = ((row_offset + thread_start_row_) % total_rows); if (subTile.size() == 32) { - select_reduce - red_obj((*frag_ptr)[frag_idx], user_params.cg_reduce_op, - tile32, shared_elem_arr[iter_row]); + select_reduce red_obj( + (*frag_ptr)[frag_idx], user_params.cg_reduce_op, tile32, shared_elem_arr[iter_row]); } else { - select_reduce - red_obj((*frag_ptr)[frag_idx], user_params.cg_reduce_op, - subTile, shared_elem_arr[iter_row]); + select_reduce red_obj( + (*frag_ptr)[frag_idx], + user_params.cg_reduce_op, + subTile, + shared_elem_arr[iter_row]); } } } From a4e45be5f54fa44e92bc4aa601f592b9166d6f43 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 17 Mar 2023 01:29:52 -0700 Subject: [PATCH 23/48] use ops::l2_exp_cutlass_op from updated changes --- cpp/include/raft/distance/detail/fused_l2_nn.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index cea52bb863..e0b41e6002 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -18,7 +18,7 @@ #include #include -#include +#include #include #include #include @@ -327,7 +327,7 @@ void fusedL2NNImpl(OutT* min, const auto deviceVersion = getComputeCapability(); if (deviceVersion.first >= 8) { - using L2Op = L2ExpandedOp; + using L2Op = raft::distance::detail::ops::l2_exp_cutlass_op; using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; kvp_cg_min_reduce_op_ cg_reduce_op; L2Op L2_dist_op(sqrt); From 7fa45e11006386b284019b0bd8e4f770c5f7cf12 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 20 Apr 2023 11:41:12 -0700 Subject: [PATCH 24/48] persistent cutlass version based on grouped gemm though only using its launcher, smem based norm vector --- .../detail/custom_epilogue_with_broadcast.h | 840 ++++++++++++++++++ .../raft/distance/detail/fused_l2_nn.cuh | 2 +- .../detail/fused_l2_nn_cutlass_base.cuh | 59 +- .../distance/detail/fused_l2_nn_epilogue.cuh | 14 +- .../fused_l2_nn_epilogue_elementwise.cuh | 5 +- .../raft/distance/detail/fused_l2_nn_gemm.h | 31 +- .../detail/fused_l2_nn_gemm_grouped_custom.h | 575 ++++++++++++ .../detail/predicated_tile_iterator_normvec.h | 13 +- .../predicated_tile_iterator_normvec_smem.h | 622 +++++++++++++ .../predicated_tile_iterator_reduced_vec.h | 257 +++--- 10 files changed, 2290 insertions(+), 128 deletions(-) create mode 100755 cpp/include/raft/distance/detail/custom_epilogue_with_broadcast.h create mode 100644 cpp/include/raft/distance/detail/fused_l2_nn_gemm_grouped_custom.h create mode 100755 cpp/include/raft/distance/detail/predicated_tile_iterator_normvec_smem.h diff --git a/cpp/include/raft/distance/detail/custom_epilogue_with_broadcast.h b/cpp/include/raft/distance/detail/custom_epilogue_with_broadcast.h new file mode 100755 index 0000000000..6852a4b447 --- /dev/null +++ b/cpp/include/raft/distance/detail/custom_epilogue_with_broadcast.h @@ -0,0 +1,840 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#include +#else +#include +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/functional.h" +#include "cutlass/fast_math.h" +#include "cutlass/layout/vector.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" + +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This base class is meant to define the concept required of the +/// EpilogueWithBroadcast::OutputOp +template < + typename ElementC_, + typename ElementAccumulator_, + typename ElementCompute_, + typename ElementZ_, + typename ElementT_, + int ElementsPerAccess, + bool StoreZ = true, + bool StoreT = true +> +struct EpilogueWithBroadcastOpBaseCustom { + + using ElementOutput = ElementC_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementZ = ElementZ_; + using ElementT = ElementT_; + static int const kElementsPerAccess = ElementsPerAccess; + + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentC = Array; + using FragmentZ = Array; + using FragmentT = Array; + + /// If true, the 'Z' tensor is stored + static bool const kStoreZ = StoreZ; + + /// If true, the 'T' tensor is stored + static bool const kStoreT = StoreT; + + /// Parameters structure - required + struct Params { }; + + // + // Methods + // + + /// Constructor from Params + EpilogueWithBroadcastOpBaseCustom(Params const ¶ms_) { } + + /// Determine if the source is needed. May return false if + bool is_source_needed() const { + return true; + } + + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { } + + /// Applies the operation when is_source_needed() is true + CUTLASS_HOST_DEVICE + void operator()( + FragmentZ &frag_Z, + FragmentT &frag_T, + FragmentAccumulator const &AB, + FragmentC const &frag_C, + FragmentCompute const &V) const { + + } + + /// Applies the operation when is_source_needed() is false + CUTLASS_HOST_DEVICE + void operator()( + FragmentZ &frag_Z, + FragmentT &frag_T, + FragmentAccumulator const &AB, + FragmentCompute const &V) const { + + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator with bias vector broadcast over columns. +/// +/// Computes the following: +/// +/// +/// Z, T = OutputOp(AB, C, Broadcast) +/// +/// if (ElementwiseOp::kStoreZ) { +/// store(converted_u); +/// } +/// +/// if (ElementwiseOp::kStoreT) { +/// store(v); +/// } +/// +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors (z) + typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands (t) + typename ElementVector_, ///< Pointer to broadcast vector + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM + typename OutputOp_, ///< Output operator - concept is EpilogueWithBroadcastOp + typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) + int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large + (!IsEpilogueFunctorHeavy::value) +> +class EpilogueWithBroadcastCustom : + public EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition> { + +public: + + using Base = EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition>; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using TensorTileIterator = TensorTileIterator_; + using ElementVector = ElementVector_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Compute data type produced by the output op + using ElementCompute = typename OutputOp::ElementCompute; + + /// Compute fragment + using FragmentCompute = Array; + + /// Thread map used by output tile iterators + using ThreadMap = typename OutputTileIterator::ThreadMap; + + /// Fragment object used to store the broadcast values + using BroadcastFragment = Array< + ElementCompute, + ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Data type of additional tensor + using ElementTensor = typename TensorTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array; + + /// Array type used by output functor + using ComputeAccessType = Array; + + /// Tensor access type + using TensorAccessType = Array; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + /// Shared memory allocation from epilogue base class + using BaseSharedStorage = typename Base::SharedStorage; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; + static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; + + /// Used for the broadcast + struct BroadcastDetail { + + /// Number of threads per warp + static int const kWarpSize = 32; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + /// Number of distinct scalar column indices handled by each thread + static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; + + /// Number of distinct scalar row indices handled by each thread + static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; + + /// Number of threads per threadblock + static int const kThreadCount = kWarpSize * WarpCount::kCount; + + /// Number of distinct threads per row of output tile + static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread); + + /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock. + static int const kThreadRows = kThreadCount / kThreadsPerRow; + + /// I'm not sure what I meant here. + static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); + + /// Shape of the shared memory allocation for the epilogue + using StorageShape = MatrixShape< + kThreadRows, + Shape::kN + >; + + /// Debug printing + CUTLASS_DEVICE + static void print() { +#if 0 + printf("BroadcastDetail {\n"); + printf( + " kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n" + "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n", + kColumnsPerThread, + kRowsPerThread, + kThreadCount, + kThreadsPerRow, + kThreadRows, + kThreadAccessesPerRow, + StorageShape::kRow, + StorageShape::kColumn, + StorageShape::kCount + ); + printf("};\n"); +#endif + } + }; + + /// Shared storage structure (shadows base) with additional SMEM buffer for reduction + struct SharedStorage { + union { + BaseSharedStorage base; + }; + + CUTLASS_HOST_DEVICE + SharedStorage() { } + }; + +public: + + + // static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, + // "Mismatch between shared load iterator and output tile iterator."); + + static_assert(SharedLoadIterator::Fragment::kElements == TensorTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), + "Divisibility"); + +private: + + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + /// Thread index within the threadblock + int thread_idx_; + +public: + + /// Constructor + CUTLASS_DEVICE + EpilogueWithBroadcastCustom( + SharedStorage &shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ): + Base(shared_storage.base, thread_idx, warp_idx, lane_idx), + shared_load_iterator_(shared_storage.base.reference(), thread_idx), + thread_idx_(thread_idx) + { + + } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + ElementVector const * broadcast_ptr, ///< Broadcast vector + //OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix + TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand + MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses + MatrixCoord(Shape::kM, Shape::kN), + MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space + MatrixCoord()) { + + BroadcastFragment broadcast_fragment; + + load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset); +#if 0 + if (!output_op.is_source_needed()) { + compute_source_not_needed_( + output_op, + broadcast_fragment, + destination_iterator, + accumulators, + tensor_iterator); + } + else { +#endif + compute_source_needed_( + output_op, + broadcast_fragment, + //destination_iterator, + accumulators, + source_iterator, + tensor_iterator); + //} + } + +private: + + CUTLASS_DEVICE + void load_broadcast_fragment_( + BroadcastFragment & broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns + ElementVector const * broadcast_ptr, ///< Broadcast vector + MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses + MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space + ) { + + broadcast_fragment.clear(); + + // If no pointer is supplied, set with all zeros and avoid memory accesses + if (!broadcast_ptr) { + return; + } + + int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column(); + + int thread_column_idx = threadblock_offset.column() + thread_initial_column; + broadcast_ptr += thread_initial_column; + + NumericArrayConverter converter; + using AccessType = AlignedArray; + using ComputeFragmentType = Array; + + ComputeFragmentType *frag_ptr = reinterpret_cast(&broadcast_fragment); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) { + + AccessType loaded; + + loaded.clear(); + + if (thread_column_idx < problem_size.column()) { + loaded = *reinterpret_cast(broadcast_ptr); + } + + ComputeFragmentType cvt = converter(loaded); + frag_ptr[j] = cvt; + + thread_column_idx += ThreadMap::Delta::kColumn; + broadcast_ptr += ThreadMap::Delta::kColumn; + } + } + + template + struct acc2smem_source_not_needed; + + template + struct acc2smem_source_not_needed> { + template + CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator &warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + warp_tile_iterator.store(accum_fragment); + if (p < Base::kFragmentsPerIteration - 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); + } + } + + if (Base::kFragmentsPerIteration > 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset * + (1 - Base::kFragmentsPerIteration)); + } + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const &iterator_begin, + WarpTileIterator &warp_tile_iterator) { + int dummy[] = { + (pos == (Seq * Base::kFragmentsPerIteration)) && + (helper(iterator_begin, warp_tile_iterator), 0)...}; + + CUTLASS_UNUSED(dummy[0]); + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + OutputOp const &output_op, ///< Output operator + BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand + ) { +#if 0 + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + + // CUTLASS_PRAGMA_UNROLL + #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) { + + // + // Convert and store fragment + // + + + __syncthreads(); + + acc2smem_source_not_needed< + cutlass::make_index_sequence>::push(iter, + accum_fragment_iterator, + this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + + + typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + if (p < Base::kFragmentsPerIteration - 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + } + else if (kPartitionsK > 1) { + + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Apply output operation + // + + typename OutputTileIterator::Fragment frag_Z; + typename TensorTileIterator::Fragment frag_T; + + apply_output_operator_source_not_needed_( + frag_Z, + frag_T, + output_op, + aligned_accum_fragment[0], + broadcast_fragment); + + // + // Conditionally store fragments + // + + if (OutputOp::kStoreZ) { + destination_iterator.store(frag_Z); + ++destination_iterator; + } + + if (OutputOp::kStoreT) { + tensor_iterator.store(frag_T); + ++tensor_iterator; + } + } + + if (Base::kFragmentsPerIteration > 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } +#endif + } + + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE + static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator &warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const &iterator_begin, + WarpTileIterator &warp_tile_iterator) { + int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const &output_op, ///< Output operator + BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns + //OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand + ) { + + typename OutputTileIterator::Fragment source_fragment; + source_fragment.clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + + #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + + // + // Load the source + // + + // source_iterator.load(source_fragment); + // ++source_iterator; + + // + // Convert and store fragment + // + + //__syncthreads(); + + acc2smem_source_needed>::push( + iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); +#if 0 + // If the number of k-slices is > 1 - perform a reduction amongst the k-slices + if (kPartitionsK > 1) + { + plus add_fragments; + const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_tile_offset({tile_row_offset , 0}); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0}); + } +#endif + // + // Apply output operation + // + + typename OutputTileIterator::Fragment frag_Z; + typename TensorTileIterator::Fragment frag_T; + + // + // Load the source + // + + source_iterator.load(source_fragment); + ++source_iterator; + + apply_output_operator_( + frag_Z, + frag_T, + output_op, + aligned_accum_fragment[0], + source_fragment, + broadcast_fragment); + + // + // Conditionally store fragments + // +#if 0 + if (OutputOp::kStoreZ) { + destination_iterator.store(frag_Z); + ++destination_iterator; + } +#endif + if (OutputOp::kStoreT) { + tensor_iterator.store(frag_T); + ++tensor_iterator; + } + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + typename OutputTileIterator::Fragment &frag_Z, + typename TensorTileIterator::Fragment &frag_T, + OutputOp const &output_op, + typename SharedLoadIterator::Fragment const &frag_AB, + typename OutputTileIterator::Fragment const &frag_C, + BroadcastFragment const &frag_Broadcast) { + + using AccessTypeZ = Array; + using AccessTypeT = Array; + using AccessTypeBroadcast = Array; + + AccessTypeZ *frag_Z_ptr = reinterpret_cast(&frag_Z); + AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); + + AccumulatorAccessType const *frag_AB_ptr = + reinterpret_cast(&frag_AB); + + OutputAccessType const *frag_C_ptr = + reinterpret_cast(&frag_C); + + AccessTypeBroadcast const *frag_Broadcast_ptr = + reinterpret_cast(&frag_Broadcast); + + // int const kOutputOpIterations = + // OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + int const kOutputOpIterations = + TensorTileIterator::Fragment::kElements / TensorTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + + output_op( + frag_Z_ptr[i], + frag_T_ptr[i], + frag_AB_ptr[i], + frag_C_ptr[(i / ThreadMap::Iterations::kColumn)], + frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + typename OutputTileIterator::Fragment &frag_Z, + typename TensorTileIterator::Fragment &frag_T, + OutputOp const &output_op, + typename SharedLoadIterator::Fragment const &frag_AB, + BroadcastFragment const &frag_Broadcast) { + + using AccessTypeZ = Array; + using AccessTypeT = Array; + using AccessTypeBroadcast = Array; + + AccessTypeZ *frag_Z_ptr = reinterpret_cast(&frag_Z); + AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); + + AccumulatorAccessType const *frag_AB_ptr = + reinterpret_cast(&frag_AB); + + AccessTypeBroadcast const *frag_Broadcast_ptr = + reinterpret_cast(&frag_Broadcast); + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + + output_op( + frag_Z_ptr[i], + frag_T_ptr[i], + frag_AB_ptr[i], + frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index e0b41e6002..25b325b1d5 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -67,7 +67,6 @@ struct MinAndDistanceReduceOpImpl { DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } DI void init(KVP* out, DataT maxVal) const { - out->key = 0; out->value = maxVal; } @@ -275,6 +274,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, template struct kvp_cg_min_reduce_op { typedef typename raft::KeyValuePair KVP; + //static const AccType maxVal; maxVal(std::numeric_limits::max()) __host__ __device__ kvp_cg_min_reduce_op() noexcept {}; diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh index bd910c0240..be20a34ae0 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh @@ -26,10 +26,10 @@ #endif #include - #include #include #include +#include #include #include @@ -39,6 +39,7 @@ #include #include #include +#include namespace raft { namespace distance { @@ -77,15 +78,13 @@ void cutlassFusedL2NNKernel(const DataT* x, DataT, // ElementCompute_ AccT, // ElementZ_ OutT, // ElementT_ - 1, // Elements per access 1 + 1, //128 / cutlass::sizeof_bits::value, // Elements per access 1 DistanceFn, FinalLambda, ReduceOpT, KVPReduceOpT>; constexpr int batch_count = 1; - constexpr auto mode = cutlass::gemm::GemmUniversalMode::kGemm; - typename EpilogueOutputOp::Params epilog_op_param(dist_op, fin_op, redOp, pairRedOp, mutexes); const DataT *a, *b; @@ -98,7 +97,8 @@ void cutlassFusedL2NNKernel(const DataT* x, constexpr int Alignment = VecLen; // default initialize problem size with row major inputs - auto problem_size = cutlass::gemm::GemmCoord(n, m, k); + //auto problem_size = cutlass::gemm::GemmCoord(n, m, k); + auto problem_size = cutlass::gemm::GemmCoord(m, n, k); constexpr bool isRowMajor = true; @@ -113,12 +113,14 @@ void cutlassFusedL2NNKernel(const DataT* x, NumStages, // Number of pipeline stages isRowMajor>::GemmKernel; +#if 0 using cutlassDist = cutlass::gemm::device::GemmUniversalAdapter; a = y; b = x; gemm_lda = ldb; gemm_ldb = lda; + constexpr auto mode = cutlass::gemm::GemmUniversalMode::kGemm; typename cutlassDist::Arguments arguments{ mode, @@ -157,6 +159,53 @@ void cutlassFusedL2NNKernel(const DataT* x, RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream)); // Launch initialized CUTLASS kernel RAFT_CUTLASS_TRY(cutlassDist_op()); +#else + + + using cutlassDist = cutlass::gemm::device::GemmGrouped; + + a = x; + b = y; + gemm_lda = lda; + gemm_ldb = ldb; + int num_blocks = cutlassDist::maximum_active_blocks(); + int num_sms = raft::getMultiProcessorCount(); + num_blocks = num_blocks * num_sms; + auto thread_blocks = std::max(num_blocks, int((problem_size.m() - 1 + cutlassDistKernel::Mma::Shape::kM)/ cutlassDistKernel::Mma::Shape::kM)); + //printf("num blocks = %d sms = %d thread_blocks_sel = %d shapekM = %d\n", num_blocks, num_sms, (int)thread_blocks, (int)cutlassDistKernel::Mma::Shape::kM); + //rmm::device_uvector problem_sizes(sizeof(decltype(problem_size)), stream); + //raft::copy(problem_sizes.data(), &problem_size, 1, stream); + typename cutlassDist::Arguments arguments{ + //problem_sizes.data(), + problem_size, + batch_count, + thread_blocks, + epilog_op_param, + a, + b, + xn, // C matrix eq vector param, which here is A norm + (DataT*)yn, // this is broadcast vec, which is required to be non-const param + dOutput, // Output distance matrix + (int64_t)gemm_lda, // stride A + (int64_t)gemm_ldb, // stride B + (int64_t)1, // stride A norm + (int64_t)ldd // stride Output matrix + }; + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = cutlassDist::get_workspace_size(arguments); + // Allocate workspace memory + rmm::device_uvector workspace(workspace_size, stream); + // Instantiate CUTLASS kernel depending on templates + cutlassDist cutlassDist_op; + // Check the problem size is supported or not + RAFT_CUTLASS_TRY(cutlassDist_op.can_implement(arguments)); + // Initialize CUTLASS kernel with arguments and workspace pointer + RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream)); + // Launch initialized CUTLASS kernel + RAFT_CUTLASS_TRY(cutlassDist_op.run(stream)); +#endif + } }; // namespace detail diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh index 8d94de3378..282ac7e906 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh @@ -39,9 +39,11 @@ operation. #include #include #include -#include +//#include +#include -#include +//#include +#include #include //////////////////////////////////////////////////////////////////////////////// @@ -71,8 +73,10 @@ struct FusedL2NNEpilogue { // // Stores the result z = (y = GEMM(A, B, C), broadcast) // - using RowNormTileIterator = cutlass::epilogue::threadblock:: - PredicatedTileIteratorNormVec; + // using RowNormTileIterator = cutlass::epilogue::threadblock:: + // PredicatedTileIteratorNormVec; + using RowNormTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorNormVecSmem< + typename Base::OutputTileThreadMap, ElementOutput, LayoutT>; // // Additional tensor tile iterator - stores t = Elementwise(z) @@ -84,7 +88,7 @@ struct FusedL2NNEpilogue { typename OutputOp::Params>; /// Define the epilogue - using Epilogue = EpilogueWithBroadcast #include +#include #include ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -59,12 +60,17 @@ struct FusedL2NNGemm { /// Threadblock-level tile size (concept: GemmShape) // <- threadblock tile M = 32, N = 64, K = 16 - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; // this is more performant + //using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 16>; // this is more performant for grouped GEMM + using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; // this is more performant for non-grouped GEMM + //using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; // SHAPE for less reg pressure grouped GEMM /// Warp-level tile size (concept: GemmShape) // This code section describes tile size a warp will compute // <- warp tile M = 64, N = 64, K = 16 - using WarpShape = cutlass::gemm::GemmShape<16, 64, 16>; // this is more performant + //using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; // this is more performant for grouped GEMM + //using WarpShape = cutlass::gemm::GemmShape<16, 32, 16>; // // SHAPE for less reg pressure grouped GEMM + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; // // SHAPE for less reg pressure grouped GEMM + //using WarpShape = cutlass::gemm::GemmShape<16, 64, 16>; // // this is more performant for non-grouped GEMM /// Warp-level tile size (concept: GemmShape) // This code section describes the size of MMA op @@ -73,6 +79,7 @@ struct FusedL2NNGemm { /// Operation performed by GEMM using Operator = cutlass::arch::OpMultiplyAddFastF32; + //using Operator = cutlass::arch::OpMultiplyAdd; // this runs only 1xTF32 for float inputs // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU // SM @@ -83,7 +90,8 @@ struct FusedL2NNGemm { // This code section describes how threadblocks are scheduled on GPU /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>; + //using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>; + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; /// data layout for final output matrix. // we keep this same layout even for column major inputs @@ -132,9 +140,12 @@ struct FusedL2NNGemm { NormXLayout, GemmBase::Epilogue::kElementsPerAccess>::Epilogue; + // Compose the GEMM kernel - using GemmKernel = - FusedL2NNWithFusedEpilogue; + // using GemmKernel = + // FusedL2NNWithFusedEpilogue; + using GemmKernel = FusedL2NNWithGemmGrouped; }; template < @@ -164,10 +175,12 @@ struct FusedL2NNGemm; + //using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; /// Warp-level tile size (concept: GemmShape) // This code section describes tile size a warp will compute // <- warp tile M = 32, N = 32, K = 16 using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; + //using WarpShape = cutlass::gemm::GemmShape<16, 32, 16>; /// Warp-level tile size (concept: GemmShape) // This code section describes the size of MMA op using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; @@ -185,6 +198,7 @@ struct FusedL2NNGemm; + /// data layout for final output matrix. // we keep this same layout even for column major inputs using LayoutOutput = cutlass::layout::RowMajor; @@ -233,8 +247,11 @@ struct FusedL2NNGemm::Epilogue; // Compose the GEMM kernel - using GemmKernel = - FusedL2NNWithFusedEpilogue; + // using GemmKernel = + // FusedL2NNWithFusedEpilogue; + using GemmKernel = FusedL2NNWithGemmGrouped; + }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_gemm_grouped_custom.h b/cpp/include/raft/distance/detail/fused_l2_nn_gemm_grouped_custom.h new file mode 100644 index 0000000000..7ee01fa6fe --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_l2_nn_gemm_grouped_custom.h @@ -0,0 +1,575 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Problem visitor for grouped GEMMs +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FusedL2NNWithGemmGrouped { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = Transposed; + + // Optional transpose + using MapArguments = kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. + using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = GemmGroupedProblemVisitor; + + // + // Structures + // + + struct temp_problem_visitor { + int problem_count; + + CUTLASS_HOST_DEVICE temp_problem_visitor() : problem_count(0) {}; + CUTLASS_HOST_DEVICE temp_problem_visitor(int problem_count_) : problem_count(problem_count_){}; + }; + + /// Argument structure + struct Arguments { + // + // Data members + // + GemmCoord problem_sizes; + temp_problem_visitor problem_visitor; + int problem_count; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + void const* ptr_A; + void const* ptr_B; + void const* ptr_C; + void* ptr_Vector; + void* ptr_Tensor; + + typename LayoutA::Stride::Index lda; + typename LayoutB::Stride::Index ldb; + typename LayoutC::Stride::Index ldc; + typename LayoutC::Stride::Index ldt; + + // Only used by device-level operator + GemmCoord* host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : //problem_count(0), + threadblock_count(0), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_Vector(nullptr), + ptr_Tensor(nullptr), + lda(0), + ldb(0), + ldc(0), + ldt(0), + host_problem_sizes(nullptr) + { + } + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(GemmCoord problem_sizes, + int problem_count, + int threadblock_count, + typename EpilogueOutputOp::Params output_op, + void const* ptr_A, + void const* ptr_B, + void const* ptr_C, + void* ptr_Vector, + void* ptr_Tensor, + typename LayoutA::Stride::Index lda, + typename LayoutB::Stride::Index ldb, + typename LayoutC::Stride::Index ldc, + typename LayoutC::Stride::Index ldt, + GemmCoord* host_problem_sizes = nullptr) + : problem_sizes(problem_sizes), + threadblock_count(threadblock_count), + output_op(output_op), + ptr_A(ptr_A), + ptr_B(ptr_B), + ptr_C(ptr_C), + ptr_Vector(ptr_Vector), + ptr_Tensor(ptr_Tensor), + lda(lda), + ldb(ldb), + ldc(ldc), + ldt(ldt), + host_problem_sizes(host_problem_sizes) + { + problem_visitor.problem_count = problem_count; + } + + + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + //typename ProblemVisitor::Params problem_visitor; + temp_problem_visitor problem_visitor; + int threadblock_count; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::TensorTileIterator::Params params_Tensor; + + typename EpilogueOutputOp::Params output_op; + + void* ptr_A; + void* ptr_B; + void* ptr_C; + void* ptr_Vector; + void* ptr_Tensor; + + GemmCoord problem_size; + typename LayoutA::Stride::Index lda; + typename LayoutB::Stride::Index ldb; + typename LayoutC::Stride::Index ldc; + typename LayoutC::Stride::Index ldt; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : params_A(0), + params_B(0), + params_C(0), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_Vector(nullptr), + ptr_Tensor(nullptr), + lda(0), + ldb(0), + ldc(0), + ldt(0) + { + } + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : //problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count), + problem_size(args.problem_sizes), + threadblock_count(args.threadblock_count), + output_op(args.output_op), + params_A(args.lda), + params_B(args.ldb), + params_C(args.ldc), + // Here we pass additional user args via args.output_op + // to the reduction output tile iterator + params_Tensor(args.ldt, args.output_op), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_C(const_cast(args.ptr_C)), + ptr_Vector(args.ptr_Vector), + ptr_Tensor(args.ptr_Tensor), + lda(args.lda), + ldb(args.ldb), + ldc(args.ldc), + ldt(args.ldt) + { + problem_visitor.problem_count = args.problem_visitor.problem_count; + } + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + { + // problem_visitor = typename ProblemVisitor::Params( + // args.problem_sizes, args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = const_cast(args.ptr_A); + ptr_B = const_cast(args.ptr_B); + ptr_C = const_cast(args.ptr_C); + ptr_Vector = args.ptr_Vector; + ptr_Tensor = args.ptr_Tensor; + lda = args.lda; + ldb = args.ldb; + ldc = args.ldc; + ldt = args.ldt; + + problem_size = args.problem_sizes; + } + }; + + struct epilogue_SharedStorage { + typename Epilogue::SharedStorage epilogue; + //typename Epilogue::TensorTileIterator::SharedStorage reduced_store; + }; + + /// Shared memory storage structure + struct SharedStorage { + union { + typename Mma::SharedStorage main_loop; + epilogue_SharedStorage epilogue_combined_store; + } kernel; + + // ProblemVisitor shared storage can't be overlapped with others + //typename ProblemVisitor::SharedStorage problem_visitor; + typename Epilogue::TensorTileIterator::SharedStorage reduced_store; + typename Epilogue::OutputTileIterator::SharedStorage rownorm_store; + + }; + + protected: + //uint32_t tile_idx; + public: + // + // Methods + // + + CUTLASS_DEVICE + FusedL2NNWithGemmGrouped() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) + { + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) { return Status::kSuccess; } + + static size_t get_extra_workspace_size(Arguments const& args, + cutlass::gemm::GemmCoord const& grid_tiled_shape) + { + return 0; + } + + CUTLASS_DEVICE + static uint32_t tile_count_(const cutlass::MatrixCoord& grid) { + return grid.row() * grid.column(); + } + + /// Get the grid shape + CUTLASS_DEVICE + static cutlass::MatrixCoord grid_shape_(const cutlass::gemm::GemmCoord& problem) { + + return cutlass::MatrixCoord( + ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN)); + } + + CUTLASS_DEVICE + bool custom_next_tile_(const cutlass::gemm::GemmCoord &problem_size, uint32_t tile_idx_) { + // Check whether the tile to compute is within the range of the current problem. + const auto grid = grid_shape_(problem_size); + const uint32_t problem_chunk = (tile_count_(grid) - 1 + gridDim.x) / gridDim.x; + const uint32_t problem_chunk_end = blockIdx.x * problem_chunk + problem_chunk; + if (tile_idx_ < problem_chunk_end) { + return true; + } + + return false; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + using ElementOut = typename Epilogue::TensorTileIterator::Element; + using LongIndexOut = typename Epilogue::TensorTileIterator::LongIndex; + using OutValTy = typename Epilogue::TensorTileIterator::OutValT; + // + // Problem visitor. + // + // ProblemVisitor problem_visitor( + // params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + const GemmCoord& problem_size = params.problem_size; + const uint32_t problem_chunk = (tile_count_(grid_shape_(problem_size)) - 1 + gridDim.x) / gridDim.x; + const uint32_t problem_chunk_end = blockIdx.x * problem_chunk + problem_chunk; + const auto grid_shape = grid_shape_(problem_size); + typename LayoutB::Index column = ((blockIdx.x * problem_chunk) % grid_shape.column()) * Mma::Shape::kN; + { + ElementOut* shared_elem_arr_ = shared_storage.reduced_store.data(); + constexpr auto maxVal_ = std::numeric_limits::max(); + + if (column) { + for (int row = threadIdx.x; row < Mma::Shape::kM; row += blockDim.x) { + params.output_op.red_op_.init(&shared_elem_arr_[row], maxVal_); + } + } + } + + { + ElementC* shared_elem_arr = shared_storage.rownorm_store.data(); + if (column) { + typename LayoutB::Index row = ((blockIdx.x * problem_chunk) / grid_shape.column()) * Mma::Shape::kM; + + uint8_t* first_tile_byte_pointer_ = reinterpret_cast(params.ptr_C) + + typename LayoutB::LongIndex(row) * typename LayoutB::LongIndex(sizeof(ElementC)); + auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + + for (int row_local = threadIdx.x ; row_local < Mma::Shape::kM; row_local += blockDim.x) { + bool guard = (row + row_local) < problem_size.m(); + cutlass::arch::cp_async(shared_elem_arr + row_local, gmem_ptr + row_local, guard); + cutlass::arch::cp_async_wait<0>(); + } + } + } + + // Outer 'persistent' loop to iterate over tiles + for (uint32_t tile_idx = blockIdx.x * problem_chunk; tile_idx < problem_chunk_end; tile_idx++) { + + const auto grid_shape = grid_shape_(problem_size); + cutlass::MatrixCoord threadblock_offset( + int(tile_idx / grid_shape.column()) * Mma::Shape::kM, + int(tile_idx % grid_shape.column()) * Mma::Shape::kN); +#if 1 + //const bool isNextTile = custom_next_tile_(problem_size, tile_idx + 1); + const bool isNextTile = ((tile_idx + 1) < problem_chunk_end); + //const bool doesRowChange = ((int((tile_idx + 1) / grid_shape.column()) * Mma::Shape::kM) == threadblock_offset.row()); + const bool doesRowChange = ((threadblock_offset.column() + Mma::Shape::kN) >= problem_size.n()); + const bool do_gmem_reduce = (doesRowChange || !isNextTile) ? true : false; +#endif + // Load element pointers. Exchange pointers and strides if working on the transpose + //const ElementA* ptr_A = reinterpret_cast((kTransposed ? params.ptr_B : params.ptr_A)); + //typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb : params.lda); + + //const ElementB* ptr_B = reinterpret_cast((kTransposed ? params.ptr_A : params.ptr_B)); + //typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda : params.ldb); + + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{threadblock_offset.row(), 0}; + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.column()}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, ptr_B, {problem_size.k(), problem_size.n()}, thread_idx, tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous tile. + //__syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementC* ptr_C = static_cast(params.ptr_C); + typename Epilogue::ElementTensor* ptr_Tensor = + static_cast(params.ptr_Tensor); + + // Define the reduction output pointer and move to the appropriate place + typename Epilogue::ElementVector* ptr_Vector = + static_cast(params.ptr_Vector); + + // Tile iterator loading from source tensor. +#if 1 + typename Epilogue::OutputTileIterator iterator_rownorm( + shared_storage.rownorm_store, + params.params_C, ptr_C, problem_size.mn(), thread_idx, + threadblock_offset); +#else + typename Epilogue::OutputTileIterator iterator_rownorm( + params.params_C, ptr_C, problem_size.mn(), thread_idx, + threadblock_offset); +#endif + + // Tile iterator writing to destination tensor. + // typename Epilogue::OutputTileIterator::Params params_D(0); + // ElementC* ptr_D = nullptr; +#if 1 + // typename Epilogue::OutputTileIterator iterator_D( + // shared_storage.rownorm_store, + // params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset); +#else + typename Epilogue::OutputTileIterator iterator_D( + params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset); +#endif + + // Additional tensor to load from + typename Epilogue::TensorTileIterator tensor_iterator( + shared_storage.reduced_store, + params.params_Tensor, + // Only the final block outputs Tensor + ptr_Tensor, + problem_size.mn(), + thread_idx, + do_gmem_reduce, + threadblock_offset); + + Epilogue epilogue(shared_storage.kernel.epilogue_combined_store.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + // Move to appropriate location for this output tile + if (ptr_Vector) { + ptr_Vector += threadblock_offset.column(); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, + ptr_Vector, + //iterator_D, + accumulators, + iterator_rownorm, + tensor_iterator, + problem_size.mn(), + threadblock_offset); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h index 67c01448dc..c977d491c0 100644 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h @@ -284,11 +284,14 @@ class PredicatedTileIteratorNormVec { CUTLASS_PRAGMA_UNROLL for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { bool guard = row_guard && mask_.predicates[column]; - - cutlass::arch::global_load( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void*)&memory_pointer[0], - guard); + if (column == 0) { + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[0], + guard); + } else { + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] = frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn]; + } } if (row + 1 < ThreadMap::Iterations::kRow) { diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec_smem.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec_smem.h new file mode 100755 index 0000000000..1c62f1a061 --- /dev/null +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec_smem.h @@ -0,0 +1,622 @@ +/* + * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + +This file contains a customized version of PredicatedTileIterator from CUTLASS 2.9.0 +(https://github.com/NVIDIA/cutlass/blob/v2.9.0/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L75) + +Changes: +- added `Layout_` template param +- Only the row index is used to load the data in load_with_byte_offset(). + This way the same normalization data is used across all columns in a row. + +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +template +class PredicatedTileIteratorNormVecSmem { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = Layout_; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + static int const total_rows = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * + kIterations; + + static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); + static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); + static_assert(ThreadMap::Iterations::kCluster > 0, "ThreadMap::Iterations::kCluster must be > 0"); + static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); + +// static_assert((ThreadMap::Iterations::kRow == 1) || (ThreadMap::Iterations::kRow == 2) +// || (ThreadMap::Iterations::kRow == 4) , "ThreadMap::Iterations::kRow must be 1, 2 or 4"); + /// Fragment object + // using Fragment = Array; + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) + { + } + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { enable(); } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + + /// Shared storage allocation needed by the predicated tile + // iterator for storing rowNorm chunk of di. + struct SharedStorage { + // + // Type definitions + // + using Shape = MatrixShape; + + /// Shape of the shared memory allocation for the reduced values store + using StorageShape = MatrixShape; + + // + // Data members + + + // + // Methods + // + AlignedBuffer storage; + + CUTLASS_DEVICE + Element* data() { return storage.data(); } + + SharedStorage() {} + }; + + private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + //uint8_t* first_tile_byte_pointer_; + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + //Index block_start_row_first_tile_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + + private: + // + // Methods + // + + protected: + SharedStorage& shared_storage_; + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorNormVecSmem(SharedStorage& shared_storage, + PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + //const bool init_shmem, + TensorCoord& threadblock_offset, + int const* indices = nullptr) + : params_(params), indices_(indices), shared_storage_(shared_storage) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + mask_.predicates[c] = + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { + mask_.clear(); + return; + } + + if (ScatterD && !indices) { mask_.clear(); } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride); + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + } + + if (threadblock_offset.column() == 0) { + Element* shared_elem_arr = shared_storage_.data(); + uint8_t* first_tile_byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(threadblock_offset.row()) * LongIndex(params_.stride); + auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + + for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { + bool guard = (threadblock_offset.row() + row) < extent_row_; + cutlass::arch::cp_async(shared_elem_arr + row, gmem_ptr + row, guard); + cutlass::arch::cp_async_wait<0>(); + } + //__syncthreads(); + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const + { + AccessType* frag_ptr = reinterpret_cast(&frag); + + Element* shared_elem_arr = shared_storage_.data(); + +#if 0 + Element row_vals[ThreadMap::Iterations::kRow]; + //static int constexpr ldsPerAccess = sizeof(Element) == 8 ? 2 : ThreadMap::Iterations::kRow; + int iter_row_ = ((thread_start_row_) % total_rows); + raft::lds(row_vals, shared_elem_arr + iter_row_); +#endif + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + int iter_row = ((row_offset + thread_start_row_) % total_rows); + (*frag_ptr)[frag_row_idx] = shared_elem_arr[iter_row]; + + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { byte_pointer += params_.increment_row; } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) const { store_with_byte_offset(frag, 0); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void downsample_load_with_byte_offset(Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + + int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + + (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; + + int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { byte_pointer += params_.increment_row; } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void upsample_load_with_byte_offset(Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const + { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + int row_add_P = add_P; + int row_add_Q = add_Q; + if (output_P > convolution_P - 2) row_add_P = 0; + if (output_Q > convolution_Q - 2) row_add_Q = 0; + + int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + + ((output_P + row_add_P) / 2) * (convolution_Q / 2) + + (output_Q + row_add_Q) / 2; + + int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { byte_pointer += params_.increment_row; } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { return MatrixCoord(thread_start_row_, thread_start_column_); } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { return thread_start_row_; } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { return thread_start_column_; } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { return extent_row_; } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { return extent_column_; } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorNormVecSmem& operator++() + { + ++state_[0]; + + if (!ScatterD) { byte_pointer_ += params_.advance_row; } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { mask_.clear(); } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { mask_.enable(); } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h index e8654c9d42..69b201c451 100644 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h @@ -88,8 +88,8 @@ class PredicatedTileIteratorReducedVec { using OutIdxT = typename EpilogueOpParams::CGReduceT::IndexT; using OutValT = typename EpilogueOpParams::CGReduceT::AccTypeT; - // static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - static int const kElementsPerAccess = 1; + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + //static int const kElementsPerAccess = 1; static int const kThreads = ThreadMap::kThreads; static int const kIterations = ThreadMap::Count::kTile; @@ -99,6 +99,10 @@ class PredicatedTileIteratorReducedVec { static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); static_assert(!UseCUDAStore, "UseCUDAStore path is not supported"); + + static int const total_rows = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * + ThreadMap::Count::kTile * ThreadMap::Delta::kRow; /// Fragment object using Fragment = Array; + using Shape = MatrixShape; /// Shape of the shared memory allocation for the reduced values store using StorageShape = MatrixShape; // // Data members - // - static const int warp_row_stride = - ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster; - static const int tile_row_stride = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * - ThreadMap::Iterations::kGroup * - ThreadMap::Iterations::kCluster; + // // Methods @@ -206,49 +202,68 @@ class PredicatedTileIteratorReducedVec { SharedStorage() {} }; - template + template struct select_reduce { /// Performs reduction and stores a reduced output to memory CUTLASS_DEVICE - select_reduce(OutT red_value, reduce_op_t reduce_op, cg_group_t cg_warp_group, OutT& shmem_ptr) + select_reduce(OutT value, cg_reduce_op_t reduce_op, + cg_group_t cg_warp_group, OutT& shmem_ptr) { - OutT reduced_val = cg::reduce(cg_warp_group, red_value, reduce_op); - if (cg_warp_group.thread_rank() == 0) { shmem_ptr = reduced_val; } + OutT element = reduce_op(shmem_ptr, value); + if (cg_warp_group.any(element == value)) { + OutT reduced_val = cg::reduce(cg_warp_group, value, reduce_op); + if (cg_warp_group.thread_rank() == 0) { + shmem_ptr = reduced_val; + } + } } }; - template - struct select_reduce> { + template + struct select_reduce> { using ValT = float; using Ty = raft::KeyValuePair; CUTLASS_DEVICE - select_reduce(Ty val_to_red, reduce_op_t reduce_op, cg_group_t cg_warp_group, Ty& shmem_ptr) + select_reduce(Ty val_to_red, cg_reduce_op_t cg_reduce_op, + cg_group_t cg_warp_group, Ty& shmem_ptr) { ValT val = val_to_red.value; - ValT reduced_val = cg::reduce(cg_warp_group, val, reduce_op); - bool pred = (reduced_val == val); - auto subTile = cg::binary_partition(cg_warp_group, pred); - if (pred) { - if (subTile.thread_rank() == 0) { shmem_ptr = val_to_red; } + Ty element = cg_reduce_op(shmem_ptr, val_to_red); + if (cg_warp_group.any(element.value == val_to_red.value)) { + ValT reduced_val = cg::reduce(cg_warp_group, val, cg_reduce_op); + bool pred = (reduced_val == val); + auto subTile = cg::binary_partition(cg_warp_group, pred); + if (pred) { + if (subTile.thread_rank() == 0) { + shmem_ptr = val_to_red; + } + } } } }; - template - struct select_reduce> { + template + struct select_reduce> { using ValT = double; using Ty = raft::KeyValuePair; CUTLASS_DEVICE - select_reduce(Ty val_to_red, reduce_op_t reduce_op, cg_group_t cg_warp_group, Ty& shmem_ptr) + select_reduce(Ty val_to_red, cg_reduce_op_t cg_reduce_op, + cg_group_t cg_warp_group, Ty& shmem_ptr) { ValT val = val_to_red.value; - ValT reduced_val = cg::reduce(cg_warp_group, val, reduce_op); - bool pred = (reduced_val == val); - auto subTile = cg::binary_partition(cg_warp_group, pred); - if (pred) { - if (subTile.thread_rank() == 0) { shmem_ptr = val_to_red; } + Ty element = cg_reduce_op(shmem_ptr, val_to_red); + if (cg_warp_group.any(element.value == val_to_red.value)) { + ValT reduced_val = cg::reduce(cg_warp_group, val, cg_reduce_op); + bool pred = (reduced_val == val); + auto subTile = cg::binary_partition(cg_warp_group, pred); + if (pred) { + if (subTile.thread_rank() == 0) { + shmem_ptr = val_to_red; + } + } } } }; @@ -284,7 +299,7 @@ class PredicatedTileIteratorReducedVec { /// Internal state counter int state_[3]; - mutable int shared_tile_id; + //mutable int shared_tile_id; /// Scatter indices int const* indices_; @@ -299,13 +314,15 @@ class PredicatedTileIteratorReducedVec { protected: SharedStorage& shared_storage_; + const bool& do_gmem_reduction_; private: // // Methods // - + //static OutValT const maxVal = std::numeric_limits::max(); public: + // // Methods // @@ -316,10 +333,13 @@ class PredicatedTileIteratorReducedVec { Element* pointer, TensorCoord extent, int thread_idx, + const bool& do_gmem_reduction, TensorCoord threadblock_offset = TensorCoord(), int const* indices = nullptr) - : params_(params), indices_(indices), shared_storage_(shared_storage) + : params_(params), indices_(indices), shared_storage_(shared_storage), + do_gmem_reduction_(do_gmem_reduction) { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; extent_row_ = extent.row(); @@ -330,8 +350,7 @@ class PredicatedTileIteratorReducedVec { TensorCoord block_offset = ThreadMap::initial_offset(0) + threadblock_offset; block_start_row_first_tile_ = block_offset.row(); - shared_tile_id = 0; - + // Initialize predicates CUTLASS_PRAGMA_UNROLL for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { @@ -339,15 +358,22 @@ class PredicatedTileIteratorReducedVec { ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); } + if (threadblock_offset.column() == 0) { + Element* shared_elem_arr = shared_storage_.data(); + EpilogueOpParams const& user_params = params_.user_param; + constexpr auto maxVal = std::numeric_limits::max(); + + for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { + user_params.red_op_.init(&shared_elem_arr[row], maxVal); + } + } + // Null pointer performs no accesses if (!pointer) { mask_.clear(); } if (ScatterD && !indices) { mask_.clear(); } // Initialize pointer - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.row()) * LongIndex(params_.stride); - first_tile_byte_pointer_ = reinterpret_cast(pointer) + LongIndex(block_offset.row()) * LongIndex(params_.stride); @@ -360,6 +386,50 @@ class PredicatedTileIteratorReducedVec { state_[0] = state_[1] = state_[2] = 0; } + /// Destructor + CUTLASS_DEVICE + ~PredicatedTileIteratorReducedVec() { + if (do_gmem_reduction_) { + EpilogueOpParams const& user_params = params_.user_param; + + auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + Element* shared_elem_arr = shared_storage_.data(); + + // If this is not optimal grid size perform mutex based gmem reduce. + if ((gridDim.x != ((extent_row_ - 1 + Shape::kRow) / Shape::kRow))) { + const auto mutex_id = (block_start_row_first_tile_ / total_rows); + // single lock per block for multiple rows + if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { + // acquire mutex lock. + while (atomicCAS(user_params.mutexes_ + mutex_id, 0, 1) == 1) + ; + } + + __syncthreads(); + + for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { + if (block_start_row_first_tile_ + row < extent_row_) { + user_params.red_op_(0, &gmem_ptr[row], shared_elem_arr[row]); + } + } + + __threadfence(); + __syncthreads(); + if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { + // release mutex lock. + atomicCAS(user_params.mutexes_ + mutex_id, 1, 0); + } + } else { + __syncthreads(); + for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { + if (block_start_row_first_tile_ + row < extent_row_) { + gmem_ptr[row] = shared_elem_arr[row]; + } + } + } + } + } + /// Adds a pointer offset in units of Element CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset) @@ -371,6 +441,7 @@ class PredicatedTileIteratorReducedVec { CUTLASS_DEVICE void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const { +#if 0 uint8_t* byte_pointer = byte_pointer_; AccessType* frag_ptr = reinterpret_cast(&frag); @@ -401,11 +472,15 @@ class PredicatedTileIteratorReducedVec { CUTLASS_PRAGMA_UNROLL for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { bool guard = row_guard && mask_.predicates[column]; + if (column == 0) { + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[0], + guard); + } else { + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] = frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn]; + } - cutlass::arch::global_load( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void*)&memory_pointer[0], - guard); } if (row + 1 < ThreadMap::Iterations::kRow) { @@ -420,6 +495,7 @@ class PredicatedTileIteratorReducedVec { byte_pointer += params_.increment_cluster; } } +#endif } /// Loads a fragment from memory @@ -434,17 +510,22 @@ class PredicatedTileIteratorReducedVec { cg::thread_block cta = cg::this_thread_block(); cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta); + // constexpr int tile_width = (32 / ThreadMap::Delta::kColumn) ? 32 : 16; + // cg::thread_block_tile tile32 = cg::tiled_partition(cta); EpilogueOpParams const& user_params = params_.user_param; using cg_reduce_t = decltype(user_params.cg_reduce_op); using tile32_t = decltype(tile32); Element* shared_elem_arr = shared_storage_.data(); - - static int const total_rows = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * - ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * - ThreadMap::Count::kTile; - + constexpr auto maxVal = std::numeric_limits::max(); + + // if (threadIdx.x == 0 && blockIdx.x == 0) { + // printf("\nIterations::kColumn = %d Iterations::kRow = %d Iterations::kGroup = %d Iterations::kCluster = %d kElementsPerAccess = %d\n", + // ThreadMap::Iterations::kColumn, ThreadMap::Iterations::kRow, ThreadMap::Iterations::kGroup, ThreadMap::Iterations::kCluster, kElementsPerAccess); + // printf("\nDelta::kColumn = %d Delta::kRow = %d Delta::kGroup = %d Delta::kCluster = %d kElementsPerAccess = %d tile_count = %d total_rows = %d\n", + // ThreadMap::Delta::kColumn, ThreadMap::Delta::kRow, ThreadMap::Delta::kGroup, ThreadMap::Delta::kCluster, kElementsPerAccess, kIterations, total_rows); + // } CUTLASS_PRAGMA_UNROLL for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { CUTLASS_PRAGMA_UNROLL @@ -460,63 +541,29 @@ class PredicatedTileIteratorReducedVec { bool row_guard = ((row_offset + thread_start_row_) < extent_row_); const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn; - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - bool guard = row_guard && mask_.predicates[column]; - - if (guard) { - const auto key_id = thread_start_column_ + ThreadMap::Delta::kColumn * column; - const int frag_col_idx = frag_idx + column; - user_params.red_op_.init_key((*frag_ptr)[frag_col_idx], key_id); - user_params.red_op_(key_id, &(*frag_ptr)[frag_idx], (*frag_ptr)[frag_col_idx]); + Element red_val; + user_params.red_op_.init(&red_val, maxVal); + if (row_guard) { + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + bool guard = mask_.predicates[column]; + if (guard) { + const auto key_id = thread_start_column_ + ThreadMap::Delta::kColumn * column; + const int frag_col_idx = frag_idx + column; + user_params.red_op_.init_key((*frag_ptr)[frag_col_idx], key_id); + user_params.red_op_(key_id, &red_val, (*frag_ptr)[frag_col_idx]); + } } - } - bool col_guard = row_guard && mask_.predicates[0]; - auto subTile = cg::binary_partition(tile32, col_guard); - using subTile_t = decltype(subTile); - - if (col_guard) { - int iter_row = ((row_offset + thread_start_row_) % total_rows); - if (subTile.size() == 32) { - select_reduce red_obj( - (*frag_ptr)[frag_idx], user_params.cg_reduce_op, tile32, shared_elem_arr[iter_row]); - } else { - select_reduce red_obj( - (*frag_ptr)[frag_idx], - user_params.cg_reduce_op, - subTile, - shared_elem_arr[iter_row]); - } - } - } - } - } - // If this is last tile then perform reduction in gmem. - if (shared_tile_id == (ThreadMap::Count::kTile - 1)) { - const auto mutex_id = (block_start_row_first_tile_ / total_rows); - // single lock per block for multiple rows - if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { - // acquire mutex lock. - while (atomicCAS(user_params.mutexes_ + mutex_id, 0, 1) == 1) - ; - } - __syncthreads(); - - auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - if (block_start_row_first_tile_ + row < extent_row_) { - user_params.red_op_(0, &gmem_ptr[row], shared_elem_arr[row]); + const int iter_row = ((row_offset + thread_start_row_) % total_rows); + // if (blockIdx.x == 0) { + // printf("iter_row = %d thread_start_row_ = %d row_offset = %d tid = %d warp_id = %d\n", (int)iter_row, (int)thread_start_row_, (int)row_offset, (int)threadIdx.x, (int)threadIdx.x / 32); + // } + select_reduce red_obj( + red_val, user_params.cg_reduce_op, tile32, shared_elem_arr[iter_row]); + } } } - - __threadfence(); - __syncthreads(); - if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { - // release mutex lock. - atomicCAS(user_params.mutexes_ + mutex_id, 1, 0); - } } } @@ -534,6 +581,7 @@ class PredicatedTileIteratorReducedVec { int add_Q, int problem_N) const { +#if 0 uint8_t* byte_pointer = byte_pointer_; AccessType* frag_ptr = reinterpret_cast(&frag); @@ -584,6 +632,7 @@ class PredicatedTileIteratorReducedVec { byte_pointer += params_.increment_cluster; } } +#endif } /// Loads a fragment from memory @@ -596,6 +645,7 @@ class PredicatedTileIteratorReducedVec { int add_Q, int problem_N) const { +#if 0 uint8_t* byte_pointer = byte_pointer_; AccessType* frag_ptr = reinterpret_cast(&frag); @@ -651,6 +701,7 @@ class PredicatedTileIteratorReducedVec { byte_pointer += params_.increment_cluster; } } +#endif } CUTLASS_DEVICE @@ -677,7 +728,7 @@ class PredicatedTileIteratorReducedVec { PredicatedTileIteratorReducedVec& operator++() { ++state_[0]; - shared_tile_id++; // tile iteration. + //shared_tile_id++; // tile iteration. if (!ScatterD) { byte_pointer_ += params_.advance_row; } From 087dbf0c016080dadab9dee5b4bd54fd165d1119 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 27 Apr 2023 07:45:54 -0700 Subject: [PATCH 25/48] add support for vectorized epilogue, move the sources to fused_distance_nn folder, polish as needed --- .../distance/detail/distance_ops/l2_exp.cuh | 4 +- .../custom_epilogue_with_broadcast.h | 77 +---- .../cutlass_base.cuh} | 122 ++------ .../epilogue.cuh} | 12 +- .../epilogue_elementwise.cuh} | 14 +- .../fusedL2NN_gemm_with_fused_epilogue.h | 0 .../gemm.h} | 35 +-- .../persistent_gemm.h} | 108 ++----- .../predicated_tile_iterator_normvec_smem.h | 221 +------------- .../predicated_tile_iterator_reduced_vec.h | 282 +++--------------- .../raft/distance/detail/fused_l2_nn.cuh | 20 +- 11 files changed, 163 insertions(+), 732 deletions(-) rename cpp/include/raft/distance/detail/{ => fused_distance_nn}/custom_epilogue_with_broadcast.h (91%) rename cpp/include/raft/distance/detail/{fused_l2_nn_cutlass_base.cuh => fused_distance_nn/cutlass_base.cuh} (53%) rename cpp/include/raft/distance/detail/{fused_l2_nn_epilogue.cuh => fused_distance_nn/epilogue.cuh} (87%) rename cpp/include/raft/distance/detail/{fused_l2_nn_epilogue_elementwise.cuh => fused_distance_nn/epilogue_elementwise.cuh} (93%) rename cpp/include/raft/distance/detail/{ => fused_distance_nn}/fusedL2NN_gemm_with_fused_epilogue.h (100%) rename cpp/include/raft/distance/detail/{fused_l2_nn_gemm.h => fused_distance_nn/gemm.h} (86%) rename cpp/include/raft/distance/detail/{fused_l2_nn_gemm_grouped_custom.h => fused_distance_nn/persistent_gemm.h} (82%) rename cpp/include/raft/distance/detail/{ => fused_distance_nn}/predicated_tile_iterator_normvec_smem.h (58%) rename cpp/include/raft/distance/detail/{ => fused_distance_nn}/predicated_tile_iterator_reduced_vec.h (62%) diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index fb00f8d66a..b45c30fa20 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -90,8 +90,8 @@ struct l2_exp_cutlass_op { // outVal could be negative due to numerical instability, especially when // calculating self distance. // clamp to 0 to avoid potential NaN in sqrt - outVal = outVal * (outVal > DataT(0.0)); - return sqrt ? raft::sqrt(outVal) : outVal; + //outVal = outVal * (outVal > DataT(0.0)); + return sqrt ? raft::sqrt(outVal * (outVal > DataT(0.0))) : outVal; } __device__ AccT operator()(DataT aData) const noexcept { return aData; } diff --git a/cpp/include/raft/distance/detail/custom_epilogue_with_broadcast.h b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h similarity index 91% rename from cpp/include/raft/distance/detail/custom_epilogue_with_broadcast.h rename to cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h index 6852a4b447..39fa4d3aea 100755 --- a/cpp/include/raft/distance/detail/custom_epilogue_with_broadcast.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h @@ -387,7 +387,6 @@ class EpilogueWithBroadcastCustom : void operator()( OutputOp const &output_op, ///< Output operator ElementVector const * broadcast_ptr, ///< Broadcast vector - //OutputTileIterator destination_iterator, ///< Tile iterator for destination AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand @@ -399,25 +398,13 @@ class EpilogueWithBroadcastCustom : BroadcastFragment broadcast_fragment; load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset); -#if 0 - if (!output_op.is_source_needed()) { - compute_source_not_needed_( - output_op, - broadcast_fragment, - destination_iterator, - accumulators, - tensor_iterator); - } - else { -#endif - compute_source_needed_( - output_op, - broadcast_fragment, - //destination_iterator, - accumulators, - source_iterator, - tensor_iterator); - //} + + compute_source_needed_( + output_op, + broadcast_fragment, + accumulators, + source_iterator, + tensor_iterator); } private: @@ -649,7 +636,6 @@ class EpilogueWithBroadcastCustom : void compute_source_needed_( OutputOp const &output_op, ///< Output operator BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns - //OutputTileIterator destination_iterator, ///< Tile iterator for destination AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand @@ -671,13 +657,6 @@ class EpilogueWithBroadcastCustom : #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { - // - // Load the source - // - - // source_iterator.load(source_fragment); - // ++source_iterator; - // // Convert and store fragment // @@ -717,7 +696,6 @@ class EpilogueWithBroadcastCustom : // Apply output operation // - typename OutputTileIterator::Fragment frag_Z; typename TensorTileIterator::Fragment frag_T; // @@ -728,22 +706,16 @@ class EpilogueWithBroadcastCustom : ++source_iterator; apply_output_operator_( - frag_Z, frag_T, output_op, aligned_accum_fragment[0], source_fragment, broadcast_fragment); + // // Conditionally store fragments // -#if 0 - if (OutputOp::kStoreZ) { - destination_iterator.store(frag_Z); - ++destination_iterator; - } -#endif if (OutputOp::kStoreT) { tensor_iterator.store(frag_T); ++tensor_iterator; @@ -754,18 +726,15 @@ class EpilogueWithBroadcastCustom : /// Helper to invoke the output functor over each vector of output CUTLASS_DEVICE void apply_output_operator_( - typename OutputTileIterator::Fragment &frag_Z, typename TensorTileIterator::Fragment &frag_T, OutputOp const &output_op, typename SharedLoadIterator::Fragment const &frag_AB, typename OutputTileIterator::Fragment const &frag_C, BroadcastFragment const &frag_Broadcast) { - using AccessTypeZ = Array; - using AccessTypeT = Array; + using AccessTypeT = Array; using AccessTypeBroadcast = Array; - AccessTypeZ *frag_Z_ptr = reinterpret_cast(&frag_Z); AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); AccumulatorAccessType const *frag_AB_ptr = @@ -777,8 +746,6 @@ class EpilogueWithBroadcastCustom : AccessTypeBroadcast const *frag_Broadcast_ptr = reinterpret_cast(&frag_Broadcast); - // int const kOutputOpIterations = - // OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; int const kOutputOpIterations = TensorTileIterator::Fragment::kElements / TensorTileIterator::kElementsPerAccess; @@ -786,7 +753,6 @@ class EpilogueWithBroadcastCustom : for (int i = 0; i < kOutputOpIterations; ++i) { output_op( - frag_Z_ptr[i], frag_T_ptr[i], frag_AB_ptr[i], frag_C_ptr[(i / ThreadMap::Iterations::kColumn)], @@ -803,31 +769,6 @@ class EpilogueWithBroadcastCustom : typename SharedLoadIterator::Fragment const &frag_AB, BroadcastFragment const &frag_Broadcast) { - using AccessTypeZ = Array; - using AccessTypeT = Array; - using AccessTypeBroadcast = Array; - - AccessTypeZ *frag_Z_ptr = reinterpret_cast(&frag_Z); - AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); - - AccumulatorAccessType const *frag_AB_ptr = - reinterpret_cast(&frag_AB); - - AccessTypeBroadcast const *frag_Broadcast_ptr = - reinterpret_cast(&frag_Broadcast); - - int const kOutputOpIterations = - OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kOutputOpIterations; ++i) { - - output_op( - frag_Z_ptr[i], - frag_T_ptr[i], - frag_AB_ptr[i], - frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); - } } }; diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh similarity index 53% rename from cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh rename to cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh index be20a34ae0..8bcf006528 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh @@ -36,8 +36,8 @@ #include #include -#include -#include +#include +#include #include #include @@ -50,11 +50,11 @@ template -void cutlassFusedL2NNKernel(const DataT* x, +void cutlassFusedDistanceNN(const DataT* x, const DataT* y, const DataT* xn, const DataT* yn, @@ -66,30 +66,27 @@ void cutlassFusedL2NNKernel(const DataT* x, IdxT ldd, OutT* dOutput, int* mutexes, - FinalLambda fin_op, + CGReduceOpT cg_reduce_op, DistanceFn dist_op, ReduceOpT redOp, KVPReduceOpT pairRedOp, cudaStream_t stream) { using EpilogueOutputOp = - cutlass::epilogue::thread::FusedL2NNEpilogueElementwise::value, // Elements per access 1 + //128 / cutlass::sizeof_bits::value, + 1, // Elements per access 1 DistanceFn, - FinalLambda, + CGReduceOpT, ReduceOpT, KVPReduceOpT>; constexpr int batch_count = 1; - typename EpilogueOutputOp::Params epilog_op_param(dist_op, fin_op, redOp, pairRedOp, mutexes); - - const DataT *a, *b; - - IdxT gemm_lda, gemm_ldb; + typename EpilogueOutputOp::Params epilog_op_param(dist_op, cg_reduce_op, redOp, pairRedOp, mutexes); // Number of pipelines you want to use constexpr int NumStages = 3; @@ -97,13 +94,12 @@ void cutlassFusedL2NNKernel(const DataT* x, constexpr int Alignment = VecLen; // default initialize problem size with row major inputs - //auto problem_size = cutlass::gemm::GemmCoord(n, m, k); auto problem_size = cutlass::gemm::GemmCoord(m, n, k); constexpr bool isRowMajor = true; - using cutlassDistKernel = - typename cutlass::gemm::kernel::FusedL2NNGemm::GemmKernel; -#if 0 - using cutlassDist = cutlass::gemm::device::GemmUniversalAdapter; - a = y; - b = x; - gemm_lda = ldb; - gemm_ldb = lda; - constexpr auto mode = cutlass::gemm::GemmUniversalMode::kGemm; + using fusedDistanceNN = cutlass::gemm::device::GemmGrouped; - typename cutlassDist::Arguments arguments{ - mode, - problem_size, - batch_count, - epilog_op_param, - a, - b, - xn, // C matrix eq vector param, which here is A norm - nullptr, // tensor_Z, - (DataT*)yn, // this is broadcast vec, which is required to be non-const param - dOutput, // Output distance matrix - (int64_t)0, // batch stride A - (int64_t)0, // batch stride B - (int64_t)0, // batch stride Norm A - (int64_t)0, - (int64_t)0, // batch stride Norm B - (int64_t)0, // batch stride Output - (int64_t)gemm_lda, // stride A - (int64_t)gemm_ldb, // stride B - 1, // stride A norm - 0, // this is no-op for Z - 0, // This must be zero - (int64_t)ldd // stride Output matrix - }; + int num_blocks_per_sm = fusedDistanceNN::maximum_active_blocks(); + int num_sms = raft::getMultiProcessorCount(); + int num_blocks = num_blocks_per_sm * num_sms; + constexpr int mmaShapeM = fusedDistanceNNKernel::Mma::Shape::kM; + auto thread_blocks = std::max(num_blocks, int((problem_size.m() - 1 + mmaShapeM) / mmaShapeM)); - // Using the arguments, query for extra workspace required for matrix multiplication computation - size_t workspace_size = cutlassDist::get_workspace_size(arguments); - // Allocate workspace memory - rmm::device_uvector workspace(workspace_size, stream); - // Instantiate CUTLASS kernel depending on templates - cutlassDist cutlassDist_op; - // Check the problem size is supported or not - RAFT_CUTLASS_TRY(cutlassDist_op.can_implement(arguments)); - // Initialize CUTLASS kernel with arguments and workspace pointer - RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream)); - // Launch initialized CUTLASS kernel - RAFT_CUTLASS_TRY(cutlassDist_op()); -#else - - - using cutlassDist = cutlass::gemm::device::GemmGrouped; - - a = x; - b = y; - gemm_lda = lda; - gemm_ldb = ldb; - int num_blocks = cutlassDist::maximum_active_blocks(); - int num_sms = raft::getMultiProcessorCount(); - num_blocks = num_blocks * num_sms; - auto thread_blocks = std::max(num_blocks, int((problem_size.m() - 1 + cutlassDistKernel::Mma::Shape::kM)/ cutlassDistKernel::Mma::Shape::kM)); - //printf("num blocks = %d sms = %d thread_blocks_sel = %d shapekM = %d\n", num_blocks, num_sms, (int)thread_blocks, (int)cutlassDistKernel::Mma::Shape::kM); - //rmm::device_uvector problem_sizes(sizeof(decltype(problem_size)), stream); - //raft::copy(problem_sizes.data(), &problem_size, 1, stream); - typename cutlassDist::Arguments arguments{ - //problem_sizes.data(), + typename fusedDistanceNN::Arguments arguments{ problem_size, - batch_count, + batch_count, // num of problems. thread_blocks, epilog_op_param, - a, - b, + x, + y, xn, // C matrix eq vector param, which here is A norm (DataT*)yn, // this is broadcast vec, which is required to be non-const param dOutput, // Output distance matrix - (int64_t)gemm_lda, // stride A - (int64_t)gemm_ldb, // stride B - (int64_t)1, // stride A norm - (int64_t)ldd // stride Output matrix + (int64_t)lda, // stride A + (int64_t)ldb, // stride B + (int64_t)1, // stride A norm + (int64_t)ldd // stride Output matrix }; // Using the arguments, query for extra workspace required for matrix multiplication computation - size_t workspace_size = cutlassDist::get_workspace_size(arguments); + size_t workspace_size = fusedDistanceNN::get_workspace_size(arguments); // Allocate workspace memory rmm::device_uvector workspace(workspace_size, stream); // Instantiate CUTLASS kernel depending on templates - cutlassDist cutlassDist_op; + fusedDistanceNN fusedDistanceNN_op; // Check the problem size is supported or not - RAFT_CUTLASS_TRY(cutlassDist_op.can_implement(arguments)); + RAFT_CUTLASS_TRY(fusedDistanceNN_op.can_implement(arguments)); // Initialize CUTLASS kernel with arguments and workspace pointer - RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream)); + RAFT_CUTLASS_TRY(fusedDistanceNN_op.initialize(arguments, workspace.data(), stream)); // Launch initialized CUTLASS kernel - RAFT_CUTLASS_TRY(cutlassDist_op.run(stream)); -#endif - + RAFT_CUTLASS_TRY(fusedDistanceNN_op.run(stream)); } }; // namespace detail diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh similarity index 87% rename from cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh rename to cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh index 282ac7e906..96c6697c02 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh @@ -39,12 +39,10 @@ operation. #include #include #include -//#include -#include +#include -//#include -#include -#include +#include +#include //////////////////////////////////////////////////////////////////////////////// @@ -65,7 +63,7 @@ template -struct FusedL2NNEpilogue { +struct FusedDistanceNNEpilogue { /// Use defaults related to the existing epilogue using Base = DefaultEpilogueTensorOp; @@ -73,8 +71,6 @@ struct FusedL2NNEpilogue { // // Stores the result z = (y = GEMM(A, B, C), broadcast) // - // using RowNormTileIterator = cutlass::epilogue::threadblock:: - // PredicatedTileIteratorNormVec; using RowNormTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorNormVecSmem< typename Base::OutputTileThreadMap, ElementOutput, LayoutT>; diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh similarity index 93% rename from cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh rename to cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh index e8c686bee6..00fafe5fa1 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_epilogue_elementwise.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh @@ -54,7 +54,7 @@ template -class FusedL2NNEpilogueElementwise { +class FusedDistanceNNEpilogueElementwise { public: using ElementOutput = ElementC_; using ElementC = ElementC_; @@ -72,7 +72,9 @@ class FusedL2NNEpilogueElementwise { using FragmentCompute = Array; using FragmentC = Array; using FragmentZ = Array; - using FragmentT = Array; + using OutValT = typename CGReduceOp::AccTypeT; + //using FragmentT = Array; + using FragmentT = Array; using FragmentOutput = FragmentZ; @@ -129,7 +131,7 @@ class FusedL2NNEpilogueElementwise { /// Constructor from Params CUTLASS_HOST_DEVICE - FusedL2NNEpilogueElementwise(Params const& params) + FusedDistanceNNEpilogueElementwise(Params const& params) : elementwise_op(params.dist_op_), pair_redop(params.pair_redop_), red_op(params.red_op_) { } @@ -148,8 +150,7 @@ class FusedL2NNEpilogueElementwise { /// Applies the operation when is_source_needed() is true CUTLASS_HOST_DEVICE - void operator()(FragmentZ& frag_Z, - FragmentT& frag_T, + void operator()(FragmentT& frag_T, FragmentAccumulator const& AB, FragmentC const& frag_C, FragmentCompute const& V) const @@ -163,7 +164,8 @@ class FusedL2NNEpilogueElementwise { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kElementsPerAccess; ++i) { ElementCompute res_Z = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); - red_op.init(&frag_T[i], res_Z); + //red_op.init(&frag_T[i], res_Z); + frag_T[i] = res_Z; } } diff --git a/cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h b/cpp/include/raft/distance/detail/fused_distance_nn/fusedL2NN_gemm_with_fused_epilogue.h similarity index 100% rename from cpp/include/raft/distance/detail/fusedL2NN_gemm_with_fused_epilogue.h rename to cpp/include/raft/distance/detail/fused_distance_nn/fusedL2NN_gemm_with_fused_epilogue.h diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h similarity index 86% rename from cpp/include/raft/distance/detail/fused_l2_nn_gemm.h rename to cpp/include/raft/distance/detail/fused_distance_nn/gemm.h index d68f1bc8a5..7b9858d9bc 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h @@ -22,9 +22,9 @@ #include #include -#include -#include -#include +//#include +#include +#include ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -55,22 +55,20 @@ template < int Stages, /// data layout row/column major of inputs bool isRowMajor> -struct FusedL2NNGemm { +struct FusedDistanceNNGemm { // This struct is specialized for fp32/3xTF32 /// Threadblock-level tile size (concept: GemmShape) // <- threadblock tile M = 32, N = 64, K = 16 - //using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 16>; // this is more performant for grouped GEMM - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; // this is more performant for non-grouped GEMM - //using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; // SHAPE for less reg pressure grouped GEMM + using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 16>; // this is more performant for grouped GEMM + //using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; // this shape has high occupancy but less perf /// Warp-level tile size (concept: GemmShape) // This code section describes tile size a warp will compute // <- warp tile M = 64, N = 64, K = 16 - //using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; // this is more performant for grouped GEMM - //using WarpShape = cutlass::gemm::GemmShape<16, 32, 16>; // // SHAPE for less reg pressure grouped GEMM - using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; // // SHAPE for less reg pressure grouped GEMM - //using WarpShape = cutlass::gemm::GemmShape<16, 64, 16>; // // this is more performant for non-grouped GEMM + using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; // this is more performant for grouped GEMM + //using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; // this shape has high occupancy but less perf + /// Warp-level tile size (concept: GemmShape) // This code section describes the size of MMA op @@ -90,7 +88,6 @@ struct FusedL2NNGemm { // This code section describes how threadblocks are scheduled on GPU /// Threadblock-level swizzling operator - //using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>; using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; /// data layout for final output matrix. @@ -129,7 +126,7 @@ struct FusedL2NNGemm { Operator>::GemmKernel; // Replace epilogue - using Epilogue = typename cutlass::epilogue::threadblock::FusedL2NNEpilogue< + using Epilogue = typename cutlass::epilogue::threadblock::FusedDistanceNNEpilogue< typename GemmBase::Epilogue::Shape, typename GemmBase::Epilogue::WarpMmaOperator, GemmBase::Epilogue::kPartitionsK, @@ -142,9 +139,7 @@ struct FusedL2NNGemm { // Compose the GEMM kernel - // using GemmKernel = - // FusedL2NNWithFusedEpilogue; - using GemmKernel = FusedL2NNWithGemmGrouped; }; @@ -163,7 +158,7 @@ template < int Stages, /// data layout row/column major of inputs bool isRowMajor> -struct FusedL2NNGemm::GemmKernel; // Replace epilogue - using Epilogue = typename cutlass::epilogue::threadblock::FusedL2NNEpilogue< + using Epilogue = typename cutlass::epilogue::threadblock::FusedDistanceNNEpilogue< typename GemmBase::Epilogue::Shape, typename GemmBase::Epilogue::WarpMmaOperator, GemmBase::Epilogue::kPartitionsK, @@ -247,9 +242,7 @@ struct FusedL2NNGemm::Epilogue; // Compose the GEMM kernel - // using GemmKernel = - // FusedL2NNWithFusedEpilogue; - using GemmKernel = FusedL2NNWithGemmGrouped; }; diff --git a/cpp/include/raft/distance/detail/fused_l2_nn_gemm_grouped_custom.h b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h similarity index 82% rename from cpp/include/raft/distance/detail/fused_l2_nn_gemm_grouped_custom.h rename to cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h index 7ee01fa6fe..b08cc0a9ab 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn_gemm_grouped_custom.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h @@ -35,17 +35,17 @@ #pragma once -#include "cutlass/complex.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" - -#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" -#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/trace.h" +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -60,7 +60,7 @@ template -struct FusedL2NNWithGemmGrouped { +struct FusedDistanceNNPersistent { public: using Mma = Mma_; using Epilogue = Epilogue_; @@ -263,8 +263,7 @@ struct FusedL2NNWithGemmGrouped { CUTLASS_HOST_DEVICE Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - : //problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count), - problem_size(args.problem_sizes), + : problem_size(args.problem_sizes), threadblock_count(args.threadblock_count), output_op(args.output_op), params_A(args.lda), @@ -289,8 +288,6 @@ struct FusedL2NNWithGemmGrouped { CUTLASS_HOST_DEVICE void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) { - // problem_visitor = typename ProblemVisitor::Params( - // args.problem_sizes, args.problem_count, workspace, tile_count); threadblock_count = args.threadblock_count; output_op = args.output_op; ptr_A = const_cast(args.ptr_A); @@ -307,34 +304,25 @@ struct FusedL2NNWithGemmGrouped { } }; - struct epilogue_SharedStorage { - typename Epilogue::SharedStorage epilogue; - //typename Epilogue::TensorTileIterator::SharedStorage reduced_store; - }; /// Shared memory storage structure struct SharedStorage { union { typename Mma::SharedStorage main_loop; - epilogue_SharedStorage epilogue_combined_store; + typename Epilogue::SharedStorage epilogue; } kernel; - // ProblemVisitor shared storage can't be overlapped with others - //typename ProblemVisitor::SharedStorage problem_visitor; typename Epilogue::TensorTileIterator::SharedStorage reduced_store; typename Epilogue::OutputTileIterator::SharedStorage rownorm_store; - }; - protected: - //uint32_t tile_idx; public: // // Methods // CUTLASS_DEVICE - FusedL2NNWithGemmGrouped() {} + FusedDistanceNNPersistent() {} /// Determines whether kernel satisfies alignment static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) @@ -351,31 +339,19 @@ struct FusedL2NNWithGemmGrouped { } CUTLASS_DEVICE - static uint32_t tile_count_(const cutlass::MatrixCoord& grid) { + static uint32_t tile_count(const cutlass::MatrixCoord& grid) { return grid.row() * grid.column(); } /// Get the grid shape CUTLASS_DEVICE - static cutlass::MatrixCoord grid_shape_(const cutlass::gemm::GemmCoord& problem) { + static cutlass::MatrixCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { return cutlass::MatrixCoord( ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN)); } - CUTLASS_DEVICE - bool custom_next_tile_(const cutlass::gemm::GemmCoord &problem_size, uint32_t tile_idx_) { - // Check whether the tile to compute is within the range of the current problem. - const auto grid = grid_shape_(problem_size); - const uint32_t problem_chunk = (tile_count_(grid) - 1 + gridDim.x) / gridDim.x; - const uint32_t problem_chunk_end = blockIdx.x * problem_chunk + problem_chunk; - if (tile_idx_ < problem_chunk_end) { - return true; - } - - return false; - } /// Executes one GEMM CUTLASS_DEVICE @@ -394,17 +370,12 @@ struct FusedL2NNWithGemmGrouped { using ElementOut = typename Epilogue::TensorTileIterator::Element; using LongIndexOut = typename Epilogue::TensorTileIterator::LongIndex; using OutValTy = typename Epilogue::TensorTileIterator::OutValT; - // - // Problem visitor. - // - // ProblemVisitor problem_visitor( - // params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); const GemmCoord& problem_size = params.problem_size; - const uint32_t problem_chunk = (tile_count_(grid_shape_(problem_size)) - 1 + gridDim.x) / gridDim.x; + const uint32_t problem_chunk = (tile_count(grid_shape(problem_size)) - 1 + gridDim.x) / gridDim.x; const uint32_t problem_chunk_end = blockIdx.x * problem_chunk + problem_chunk; - const auto grid_shape = grid_shape_(problem_size); - typename LayoutB::Index column = ((blockIdx.x * problem_chunk) % grid_shape.column()) * Mma::Shape::kN; + const auto grid_shape_ = grid_shape(problem_size); + typename LayoutB::Index column = ((blockIdx.x * problem_chunk) % grid_shape_.column()) * Mma::Shape::kN; { ElementOut* shared_elem_arr_ = shared_storage.reduced_store.data(); constexpr auto maxVal_ = std::numeric_limits::max(); @@ -419,7 +390,7 @@ struct FusedL2NNWithGemmGrouped { { ElementC* shared_elem_arr = shared_storage.rownorm_store.data(); if (column) { - typename LayoutB::Index row = ((blockIdx.x * problem_chunk) / grid_shape.column()) * Mma::Shape::kM; + typename LayoutB::Index row = ((blockIdx.x * problem_chunk) / grid_shape_.column()) * Mma::Shape::kM; uint8_t* first_tile_byte_pointer_ = reinterpret_cast(params.ptr_C) + typename LayoutB::LongIndex(row) * typename LayoutB::LongIndex(sizeof(ElementC)); @@ -436,23 +407,14 @@ struct FusedL2NNWithGemmGrouped { // Outer 'persistent' loop to iterate over tiles for (uint32_t tile_idx = blockIdx.x * problem_chunk; tile_idx < problem_chunk_end; tile_idx++) { - const auto grid_shape = grid_shape_(problem_size); + const auto grid_shape_ = grid_shape(problem_size); cutlass::MatrixCoord threadblock_offset( - int(tile_idx / grid_shape.column()) * Mma::Shape::kM, - int(tile_idx % grid_shape.column()) * Mma::Shape::kN); -#if 1 - //const bool isNextTile = custom_next_tile_(problem_size, tile_idx + 1); + int(tile_idx / grid_shape_.column()) * Mma::Shape::kM, + int(tile_idx % grid_shape_.column()) * Mma::Shape::kN); + const bool isNextTile = ((tile_idx + 1) < problem_chunk_end); - //const bool doesRowChange = ((int((tile_idx + 1) / grid_shape.column()) * Mma::Shape::kM) == threadblock_offset.row()); const bool doesRowChange = ((threadblock_offset.column() + Mma::Shape::kN) >= problem_size.n()); const bool do_gmem_reduce = (doesRowChange || !isNextTile) ? true : false; -#endif - // Load element pointers. Exchange pointers and strides if working on the transpose - //const ElementA* ptr_A = reinterpret_cast((kTransposed ? params.ptr_B : params.ptr_A)); - //typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb : params.lda); - - //const ElementB* ptr_B = reinterpret_cast((kTransposed ? params.ptr_A : params.ptr_B)); - //typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda : params.ldb); ElementA* ptr_A = static_cast(params.ptr_A); ElementB* ptr_B = static_cast(params.ptr_B); @@ -511,28 +473,10 @@ struct FusedL2NNWithGemmGrouped { static_cast(params.ptr_Vector); // Tile iterator loading from source tensor. -#if 1 typename Epilogue::OutputTileIterator iterator_rownorm( shared_storage.rownorm_store, params.params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset); -#else - typename Epilogue::OutputTileIterator iterator_rownorm( - params.params_C, ptr_C, problem_size.mn(), thread_idx, - threadblock_offset); -#endif - - // Tile iterator writing to destination tensor. - // typename Epilogue::OutputTileIterator::Params params_D(0); - // ElementC* ptr_D = nullptr; -#if 1 - // typename Epilogue::OutputTileIterator iterator_D( - // shared_storage.rownorm_store, - // params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset); -#else - typename Epilogue::OutputTileIterator iterator_D( - params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset); -#endif // Additional tensor to load from typename Epilogue::TensorTileIterator tensor_iterator( @@ -545,7 +489,7 @@ struct FusedL2NNWithGemmGrouped { do_gmem_reduce, threadblock_offset); - Epilogue epilogue(shared_storage.kernel.epilogue_combined_store.epilogue, thread_idx, warp_idx, lane_idx); + Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); // Execute the epilogue operator to update the destination tensor. // Move to appropriate location for this output tile diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec_smem.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h similarity index 58% rename from cpp/include/raft/distance/detail/predicated_tile_iterator_normvec_smem.h rename to cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h index 1c62f1a061..20e44521b9 100755 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec_smem.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h @@ -26,7 +26,6 @@ This file contains a customized version of PredicatedTileIterator from CUTLASS 2 This way the same normalization data is used across all columns in a row. */ - #pragma once #include @@ -82,22 +81,19 @@ class PredicatedTileIteratorNormVecSmem { static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; static int const kThreads = ThreadMap::kThreads; static int const kIterations = ThreadMap::Count::kTile; + // static int const total_rows = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * + // ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * + // kIterations; + static int const total_rows = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * - kIterations; + ThreadMap::Count::kTile * ThreadMap::Delta::kRow; static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); static_assert(ThreadMap::Iterations::kCluster > 0, "ThreadMap::Iterations::kCluster must be > 0"); static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); -// static_assert((ThreadMap::Iterations::kRow == 1) || (ThreadMap::Iterations::kRow == 2) -// || (ThreadMap::Iterations::kRow == 4) , "ThreadMap::Iterations::kRow must be 1, 2 or 4"); - /// Fragment object - // using Fragment = Array; using Fragment = Array(shared_elem_arr + row, gmem_ptr + row, guard); cutlass::arch::cp_async_wait<0>(); } - //__syncthreads(); } // Initialize internal state counter @@ -321,12 +314,6 @@ class PredicatedTileIteratorNormVecSmem { Element* shared_elem_arr = shared_storage_.data(); -#if 0 - Element row_vals[ThreadMap::Iterations::kRow]; - //static int constexpr ldsPerAccess = sizeof(Element) == 8 ? 2 : ThreadMap::Iterations::kRow; - int iter_row_ = ((thread_start_row_) % total_rows); - raft::lds(row_vals, shared_elem_arr + iter_row_); -#endif CUTLASS_PRAGMA_UNROLL for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { CUTLASS_PRAGMA_UNROLL @@ -339,211 +326,21 @@ class PredicatedTileIteratorNormVecSmem { int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; int iter_row = ((row_offset + thread_start_row_) % total_rows); - (*frag_ptr)[frag_row_idx] = shared_elem_arr[iter_row]; - - } - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } - - /// Stores a fragment to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const - { - uint8_t* byte_pointer = byte_pointer_; - AccessType const* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - if (ScatterD && row_guard) { - assert(indices_); - - memory_pointer = reinterpret_cast( - byte_pointer + byte_offset + - LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); - } + Element val = shared_elem_arr[iter_row]; CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - bool guard = row_guard && mask_.predicates[column]; - - if (UseCUDAStore) { - if (guard) { - memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] = - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; - } - } else { - cutlass::arch::global_store( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], - guard); - } - } - - if (row + 1 < ThreadMap::Iterations::kRow) { - if (!ScatterD) { byte_pointer += params_.increment_row; } + for (int i = 0; i < kElementsPerAccess; ++i) { + (*frag_ptr)[frag_row_idx + i] = val; } } - - if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; } } } - /// Stores a fragment to memory - CUTLASS_DEVICE - void store(Fragment const& frag) const { store_with_byte_offset(frag, 0); } - /// Loads a fragment from memory CUTLASS_DEVICE - void downsample_load_with_byte_offset(Fragment& frag, - int64_t byte_offset, - int convolution_P, - int convolution_Q, - int add_P, - int add_Q, - int problem_N) const - { - uint8_t* byte_pointer = byte_pointer_; - AccessType* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - int output_row = row_offset + thread_start_row_; - int output_N = output_row / (convolution_P * convolution_Q); - int output_PQ = output_row % (convolution_P * convolution_Q); - int output_P = output_PQ / convolution_Q; - int output_Q = output_PQ % convolution_Q; - - int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + - (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; - - int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); - - AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - bool guard = row_guard && mask_.predicates[column]; - - cutlass::arch::global_load( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], - guard); - } - - if (row + 1 < ThreadMap::Iterations::kRow) { byte_pointer += params_.increment_row; } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void upsample_load_with_byte_offset(Fragment& frag, - int64_t byte_offset, - int convolution_P, - int convolution_Q, - int add_P, - int add_Q, - int problem_N) const - { - uint8_t* byte_pointer = byte_pointer_; - AccessType* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - int output_row = row_offset + thread_start_row_; - int output_N = output_row / (convolution_P * convolution_Q); - int output_PQ = output_row % (convolution_P * convolution_Q); - int output_P = output_PQ / convolution_Q; - int output_Q = output_PQ % convolution_Q; - int row_add_P = add_P; - int row_add_Q = add_Q; - if (output_P > convolution_P - 2) row_add_P = 0; - if (output_Q > convolution_Q - 2) row_add_Q = 0; - - int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + - ((output_P + row_add_P) / 2) * (convolution_Q / 2) + - (output_Q + row_add_Q) / 2; - - int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); - - AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - bool guard = row_guard && mask_.predicates[column]; - - cutlass::arch::global_load( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], - guard); - } - - if (row + 1 < ThreadMap::Iterations::kRow) { byte_pointer += params_.increment_row; } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } - } + void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } - } CUTLASS_DEVICE MatrixCoord thread_start() const { return MatrixCoord(thread_start_row_, thread_start_column_); } diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h similarity index 62% rename from cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h rename to cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index 69b201c451..6ba884f318 100644 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -105,12 +105,13 @@ class PredicatedTileIteratorReducedVec { ThreadMap::Count::kTile * ThreadMap::Delta::kRow; /// Fragment object using Fragment = - Array; - /// Memory access size + // Memory access size using AccessType = AlignedArray; + using AccessTypeValT = AlignedArray; // // Parameters struct @@ -146,7 +147,8 @@ class PredicatedTileIteratorReducedVec { /// Mask object struct Mask { - static int const kCount = ThreadMap::Iterations::kColumn; + //static int const kCount = ThreadMap::Iterations::kColumn; + static int const kCount = ThreadMap::Iterations::kColumn * kElementsPerAccess; /// Predicate state bool predicates[kCount]; @@ -207,11 +209,10 @@ class PredicatedTileIteratorReducedVec { struct select_reduce { /// Performs reduction and stores a reduced output to memory CUTLASS_DEVICE - select_reduce(OutT value, cg_reduce_op_t reduce_op, + select_reduce(OutT value, ValT prev_red_val, cg_reduce_op_t reduce_op, cg_group_t cg_warp_group, OutT& shmem_ptr) { - OutT element = reduce_op(shmem_ptr, value); - if (cg_warp_group.any(element == value)) { + if (cg_warp_group.any(reduce_op.isAmin(value, prev_red_val))) { OutT reduced_val = cg::reduce(cg_warp_group, value, reduce_op); if (cg_warp_group.thread_rank() == 0) { shmem_ptr = reduced_val; @@ -226,12 +227,12 @@ class PredicatedTileIteratorReducedVec { using Ty = raft::KeyValuePair; CUTLASS_DEVICE - select_reduce(Ty val_to_red, cg_reduce_op_t cg_reduce_op, + select_reduce(Ty val_to_red, float prev_red_val, cg_reduce_op_t cg_reduce_op, cg_group_t cg_warp_group, Ty& shmem_ptr) { ValT val = val_to_red.value; - Ty element = cg_reduce_op(shmem_ptr, val_to_red); - if (cg_warp_group.any(element.value == val_to_red.value)) { + + if (cg_warp_group.any(cg_reduce_op.isAmin(val, prev_red_val))) { ValT reduced_val = cg::reduce(cg_warp_group, val, cg_reduce_op); bool pred = (reduced_val == val); auto subTile = cg::binary_partition(cg_warp_group, pred); @@ -250,12 +251,12 @@ class PredicatedTileIteratorReducedVec { using Ty = raft::KeyValuePair; CUTLASS_DEVICE - select_reduce(Ty val_to_red, cg_reduce_op_t cg_reduce_op, + select_reduce(Ty val_to_red, double prev_red_val, cg_reduce_op_t cg_reduce_op, cg_group_t cg_warp_group, Ty& shmem_ptr) { ValT val = val_to_red.value; - Ty element = cg_reduce_op(shmem_ptr, val_to_red); - if (cg_warp_group.any(element.value == val_to_red.value)) { + + if (cg_warp_group.any(cg_reduce_op.isAmin(val, prev_red_val))) { ValT reduced_val = cg::reduce(cg_warp_group, val, cg_reduce_op); bool pred = (reduced_val == val); auto subTile = cg::binary_partition(cg_warp_group, pred); @@ -320,7 +321,6 @@ class PredicatedTileIteratorReducedVec { // // Methods // - //static OutValT const maxVal = std::numeric_limits::max(); public: // @@ -353,9 +353,10 @@ class PredicatedTileIteratorReducedVec { // Initialize predicates CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { - mask_.predicates[c] = - ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); + for (int c = 0; c < ThreadMap::Iterations::kColumn * kElementsPerAccess; ++c) { + int columnPerAccess = (c / kElementsPerAccess); + int columnWithinPerAccess = c % kElementsPerAccess; + mask_.predicates[c] = ((thread_offset.column() + ThreadMap::Delta::kColumn * columnPerAccess + columnWithinPerAccess) < extent.column()); } if (threadblock_offset.column() == 0) { @@ -401,8 +402,13 @@ class PredicatedTileIteratorReducedVec { // single lock per block for multiple rows if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { // acquire mutex lock. - while (atomicCAS(user_params.mutexes_ + mutex_id, 0, 1) == 1) - ; + unsigned int ns = 8; + while (atomicCAS(user_params.mutexes_ + mutex_id, 0, 1) == 1) { + __nanosleep(ns); + if (ns < 256) { + ns *= 2; + } + } } __syncthreads(); @@ -437,81 +443,16 @@ class PredicatedTileIteratorReducedVec { byte_pointer_ += pointer_offset * sizeof_bits::value / 8; } - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const - { -#if 0 - uint8_t* byte_pointer = byte_pointer_; - AccessType* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - if (ScatterD && row_guard) { - assert(indices_); - - memory_pointer = reinterpret_cast( - byte_pointer + byte_offset + - LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); - } - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - bool guard = row_guard && mask_.predicates[column]; - if (column == 0) { - cutlass::arch::global_load( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void*)&memory_pointer[0], - guard); - } else { - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] = frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn]; - } - - } - - if (row + 1 < ThreadMap::Iterations::kRow) { - if (!ScatterD) { byte_pointer += params_.increment_row; } - } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } -#endif - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } - /// Performs reduction and Stores a reduced output to memory CUTLASS_DEVICE void store_with_byte_offset(Fragment& frag, int64_t byte_offset) const { - AccessType* frag_ptr = reinterpret_cast(&frag); + AccessTypeValT* frag_ptr = reinterpret_cast(&frag); - cg::thread_block cta = cg::this_thread_block(); - cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta); - // constexpr int tile_width = (32 / ThreadMap::Delta::kColumn) ? 32 : 16; - // cg::thread_block_tile tile32 = cg::tiled_partition(cta); + cg::thread_block cta = cg::this_thread_block(); + // tile_width 16 is required if kElementPerAccess > 1 + constexpr int tile_width = (32 / ThreadMap::Delta::kColumn) ? 32 : 16; + cg::thread_block_tile tile32 = cg::tiled_partition(cta); EpilogueOpParams const& user_params = params_.user_param; using cg_reduce_t = decltype(user_params.cg_reduce_op); @@ -520,12 +461,6 @@ class PredicatedTileIteratorReducedVec { Element* shared_elem_arr = shared_storage_.data(); constexpr auto maxVal = std::numeric_limits::max(); - // if (threadIdx.x == 0 && blockIdx.x == 0) { - // printf("\nIterations::kColumn = %d Iterations::kRow = %d Iterations::kGroup = %d Iterations::kCluster = %d kElementsPerAccess = %d\n", - // ThreadMap::Iterations::kColumn, ThreadMap::Iterations::kRow, ThreadMap::Iterations::kGroup, ThreadMap::Iterations::kCluster, kElementsPerAccess); - // printf("\nDelta::kColumn = %d Delta::kRow = %d Delta::kGroup = %d Delta::kCluster = %d kElementsPerAccess = %d tile_count = %d total_rows = %d\n", - // ThreadMap::Delta::kColumn, ThreadMap::Delta::kRow, ThreadMap::Delta::kGroup, ThreadMap::Delta::kCluster, kElementsPerAccess, kIterations, total_rows); - // } CUTLASS_PRAGMA_UNROLL for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { CUTLASS_PRAGMA_UNROLL @@ -540,27 +475,31 @@ class PredicatedTileIteratorReducedVec { bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn; + const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn * kElementsPerAccess; Element red_val; user_params.red_op_.init(&red_val, maxVal); if (row_guard) { + + const int iter_row = ((row_offset + thread_start_row_) % total_rows); + const auto prev_red_val = user_params.red_op_.get_value(shared_elem_arr[iter_row]); + CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + for (int column = 0; column < ThreadMap::Iterations::kColumn * kElementsPerAccess; ++column) { + int columnPerAccess = column / kElementsPerAccess; + int columnWithPerAccess = column % kElementsPerAccess; bool guard = mask_.predicates[column]; if (guard) { - const auto key_id = thread_start_column_ + ThreadMap::Delta::kColumn * column; + const OutIdxT key_id = thread_start_column_ + ThreadMap::Delta::kColumn * columnPerAccess + columnWithPerAccess; const int frag_col_idx = frag_idx + column; - user_params.red_op_.init_key((*frag_ptr)[frag_col_idx], key_id); - user_params.red_op_(key_id, &red_val, (*frag_ptr)[frag_col_idx]); + + Element this_val; + user_params.red_op_.init(&this_val, (*frag_ptr)[frag_col_idx]); + user_params.red_op_.init_key(this_val, key_id ); + user_params.red_op_(key_id , &red_val, this_val); } } - - const int iter_row = ((row_offset + thread_start_row_) % total_rows); - // if (blockIdx.x == 0) { - // printf("iter_row = %d thread_start_row_ = %d row_offset = %d tid = %d warp_id = %d\n", (int)iter_row, (int)thread_start_row_, (int)row_offset, (int)threadIdx.x, (int)threadIdx.x / 32); - // } select_reduce red_obj( - red_val, user_params.cg_reduce_op, tile32, shared_elem_arr[iter_row]); + red_val, prev_red_val, user_params.cg_reduce_op, tile32, shared_elem_arr[iter_row]); } } } @@ -571,139 +510,6 @@ class PredicatedTileIteratorReducedVec { CUTLASS_DEVICE void store(Fragment& frag) const { store_with_byte_offset(frag, 0); } - /// Loads a fragment from memory - CUTLASS_DEVICE - void downsample_load_with_byte_offset(Fragment& frag, - int64_t byte_offset, - int convolution_P, - int convolution_Q, - int add_P, - int add_Q, - int problem_N) const - { -#if 0 - uint8_t* byte_pointer = byte_pointer_; - AccessType* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - int output_row = row_offset + thread_start_row_; - int output_N = output_row / (convolution_P * convolution_Q); - int output_PQ = output_row % (convolution_P * convolution_Q); - int output_P = output_PQ / convolution_Q; - int output_Q = output_PQ % convolution_Q; - - int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + - (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; - - int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); - - AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - bool guard = row_guard && mask_.predicates[column]; - - cutlass::arch::global_load( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], - guard); - } - - if (row + 1 < ThreadMap::Iterations::kRow) { byte_pointer += params_.increment_row; } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } -#endif - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void upsample_load_with_byte_offset(Fragment& frag, - int64_t byte_offset, - int convolution_P, - int convolution_Q, - int add_P, - int add_Q, - int problem_N) const - { -#if 0 - uint8_t* byte_pointer = byte_pointer_; - AccessType* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - int output_row = row_offset + thread_start_row_; - int output_N = output_row / (convolution_P * convolution_Q); - int output_PQ = output_row % (convolution_P * convolution_Q); - int output_P = output_PQ / convolution_Q; - int output_Q = output_PQ % convolution_Q; - int row_add_P = add_P; - int row_add_Q = add_Q; - if (output_P > convolution_P - 2) row_add_P = 0; - if (output_Q > convolution_Q - 2) row_add_Q = 0; - - int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + - ((output_P + row_add_P) / 2) * (convolution_Q / 2) + - (output_Q + row_add_Q) / 2; - - int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); - - AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - bool guard = row_guard && mask_.predicates[column]; - - cutlass::arch::global_load( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], - guard); - } - - if (row + 1 < ThreadMap::Iterations::kRow) { byte_pointer += params_.increment_row; } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } -#endif - } - CUTLASS_DEVICE MatrixCoord thread_start() const { return MatrixCoord(thread_start_row_, thread_start_column_); } diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 25b325b1d5..abdce95b64 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include #include @@ -72,6 +72,15 @@ struct MinAndDistanceReduceOpImpl { DI void init_key(DataT& out, LabelT idx) const { return; } DI void init_key(KVP& out, LabelT idx) const { out.key = idx; } + + DI DataT get_value(KVP& out) const + { + return out.value;; + } + DI DataT get_value(DataT& out) const + { + return out; + } }; template @@ -274,7 +283,6 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, template struct kvp_cg_min_reduce_op { typedef typename raft::KeyValuePair KVP; - //static const AccType maxVal; maxVal(std::numeric_limits::max()) __host__ __device__ kvp_cg_min_reduce_op() noexcept {}; @@ -282,7 +290,11 @@ struct kvp_cg_min_reduce_op { using IndexT = Index; // functor signature. __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } - __host__ __device__ AccType operator()(AccType a, AccType b) const { return a < b ? a : b; } + +__host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } + +__host__ __device__ bool isAmin(AccType a, AccType b) const { return a < b ? true : false; } + }; template Date: Fri, 28 Apr 2023 01:10:49 -0700 Subject: [PATCH 26/48] remove the data parallel fusedL2NN cutlass source, fix the connect_components fusedL2NN struct to adhere to new functions --- .../fusedL2NN_gemm_with_fused_epilogue.h | 734 ------------------ .../neighbors/detail/connect_components.cuh | 11 +- 2 files changed, 10 insertions(+), 735 deletions(-) delete mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/fusedL2NN_gemm_with_fused_epilogue.h diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/fusedL2NN_gemm_with_fused_epilogue.h b/cpp/include/raft/distance/detail/fused_distance_nn/fusedL2NN_gemm_with_fused_epilogue.h deleted file mode 100644 index b933ab9a7b..0000000000 --- a/cpp/include/raft/distance/detail/fused_distance_nn/fusedL2NN_gemm_with_fused_epilogue.h +++ /dev/null @@ -1,734 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Gemm kernel with fused reduction operation. - -This file contains a customized version of GemmWithFusedEpilogue from CUTLASS 2.9.0 -(https://github.com/NVIDIA/cutlass/blob/v2.9.0/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h) -* Changes: --- added additional input parameter to params_Tensor constructor, - for passing user inputs to PredicatedTileIterator of reduced output values. -*/ - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FusedL2NNWithFusedEpilogue { - public: - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; - - static ComplexTransform const kTransformA = Mma::kTransformA; - static ComplexTransform const kTransformB = Mma::kTransformB; - using Operator = typename Mma::Operator; - - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - /// Split-K preserves splits that are 128b aligned - static int const kSplitKAlignment = - const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); - - // - // Structures - // - - /// Argument structure - struct Arguments { - // - // Data members - // - - GemmUniversalMode mode; - GemmCoord problem_size; - int batch_count; - - typename EpilogueOutputOp::Params epilogue; - - void const* ptr_A; - void const* ptr_B; - void const* ptr_C; - void* ptr_D; - - void* ptr_Vector; - void* ptr_Tensor; - - int64_t batch_stride_A; - int64_t batch_stride_B; - int64_t batch_stride_C; - int64_t batch_stride_D; - int64_t batch_stride_Vector; - int64_t batch_stride_Tensor; - - typename LayoutA::Stride::Index lda; - typename LayoutB::Stride::Index ldb; - typename LayoutC::Stride::Index ldc; - typename LayoutC::Stride::Index ldd; - typename LayoutC::Stride::Index ldr; - typename LayoutC::Stride::Index ldt; - - // - // Methods - // - - Arguments() - : mode(GemmUniversalMode::kGemm), - batch_count(1), - ptr_A(nullptr), - ptr_B(nullptr), - ptr_C(nullptr), - ptr_D(nullptr) - { - } - - /// constructs an arguments structure - Arguments(GemmUniversalMode mode, - GemmCoord problem_size, - int batch_count, - typename EpilogueOutputOp::Params epilogue, - void const* ptr_A, - void const* ptr_B, - void const* ptr_C, - void* ptr_D, - void* ptr_Vector, - void* ptr_Tensor, - int64_t batch_stride_A, - int64_t batch_stride_B, - int64_t batch_stride_C, - int64_t batch_stride_D, - int64_t batch_stride_Vector, - int64_t batch_stride_Tensor, - typename LayoutA::Stride::Index lda, - typename LayoutB::Stride::Index ldb, - typename LayoutC::Stride::Index ldc, - typename LayoutC::Stride::Index ldd, - typename LayoutC::Stride::Index ldr, - typename LayoutC::Stride::Index ldt) - : mode(mode), - problem_size(problem_size), - batch_count(batch_count), - epilogue(epilogue), - ptr_A(ptr_A), - ptr_B(ptr_B), - ptr_C(ptr_C), - ptr_D(ptr_D), - ptr_Vector(ptr_Vector), - ptr_Tensor(ptr_Tensor), - batch_stride_A(batch_stride_A), - batch_stride_B(batch_stride_B), - batch_stride_C(batch_stride_C), - batch_stride_D(batch_stride_D), - batch_stride_Vector(batch_stride_Vector), - batch_stride_Tensor(batch_stride_Tensor), - lda(lda), - ldb(ldb), - ldc(ldc), - ldd(ldd), - ldr(ldr), - ldt(ldt) - { - CUTLASS_TRACE_HOST( - "FusedL2NNWithFusedEpilogue::Arguments::Arguments() - problem_size: " << problem_size); - CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void*)this->ptr_Reduction); - CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void*)this->ptr_Tensor); - CUTLASS_TRACE_HOST(" ldr: " << this->ldr); - CUTLASS_TRACE_HOST(" ldt: " << this->ldt); - } - - /// Returns arguments for the transposed problem - Arguments transposed_problem() const - { - Arguments args(*this); - - std::swap(args.problem_size.m(), args.problem_size.n()); - std::swap(args.ptr_A, args.ptr_B); - std::swap(args.lda, args.ldb); - std::swap(args.batch_stride_A, args.batch_stride_B); - - return args; - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // - - /// Parameters structure - struct Params { - cutlass::gemm::GemmCoord problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; - - typename Mma::IteratorA::Params params_A; - typename Mma::IteratorB::Params params_B; - typename Epilogue::OutputTileIterator::Params params_C; - typename Epilogue::OutputTileIterator::Params params_D; - typename Epilogue::TensorTileIterator::Params params_Tensor; - - typename EpilogueOutputOp::Params output_op; - - GemmUniversalMode mode; - int batch_count; - int gemm_k_size; - - void* ptr_A; - void* ptr_B; - void* ptr_C; - void* ptr_D; - - void* ptr_Vector; - typename LayoutC::Stride::Index ldr; - - void* ptr_Tensor; - - int64_t batch_stride_A; - int64_t batch_stride_B; - int64_t batch_stride_C; - int64_t batch_stride_D; - int64_t batch_stride_Vector; - int64_t batch_stride_Tensor; - - int* semaphore; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() - : swizzle_log_tile(0), - params_A(0), - params_B(0), - params_C(0), - params_D(0), - batch_count(0), - gemm_k_size(0), - mode(cutlass::gemm::GemmUniversalMode::kGemm), - ptr_A(nullptr), - ptr_B(nullptr), - ptr_C(nullptr), - ptr_D(nullptr), - ptr_Vector(nullptr), - ldr(0), - ptr_Tensor(nullptr), - batch_stride_A(0), - batch_stride_B(0), - batch_stride_C(0), - batch_stride_D(0), - batch_stride_Vector(0), - batch_stride_Tensor(0), - semaphore(nullptr) - { - } - - CUTLASS_HOST_DEVICE - Params(Arguments const& args, - cutlass::gemm::GemmCoord const& grid_tiled_shape, - int gemm_k_size, - void* workspace = nullptr) - : problem_size(args.problem_size), - grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), - params_A(args.lda), - params_B(args.ldb), - params_C(args.ldc), - params_D(args.ldd), - // Here we pass additional user args via args.epilogue - params_Tensor(args.ldt, args.epilogue), - output_op(args.epilogue), - mode(args.mode), - batch_count(args.batch_count), - gemm_k_size(gemm_k_size), - ptr_A(const_cast(args.ptr_A)), - ptr_B(const_cast(args.ptr_B)), - ptr_C(const_cast(args.ptr_C)), - ptr_D(args.ptr_D), - ptr_Vector(args.ptr_Vector), - ldr(args.ldr), - ptr_Tensor(args.ptr_Tensor), - - batch_stride_A(args.batch_stride_A), - batch_stride_B(args.batch_stride_B), - batch_stride_C(args.batch_stride_C), - batch_stride_D(args.batch_stride_D), - batch_stride_Vector(args.batch_stride_Vector), - batch_stride_Tensor(args.batch_stride_Tensor), - - semaphore(static_cast(workspace)) - { - CUTLASS_TRACE_HOST( - "FusedL2NNWithFusedEpilogue::Params::Params() - problem_size: " << problem_size); - CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void*)this->ptr_Reduction); - CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void*)this->ptr_Tensor); - CUTLASS_TRACE_HOST(" ldr: " << this->ldr); - CUTLASS_TRACE_HOST(" ldt: " << args.ldt); - } - - CUTLASS_HOST_DEVICE - void update(Arguments const& args, void* workspace = nullptr) - { - ptr_A = const_cast(args.ptr_A); - ptr_B = const_cast(args.ptr_B); - ptr_C = const_cast(args.ptr_C); - ptr_D = args.ptr_D; - - ptr_Vector = args.ptr_Vector; - ldr = args.ldr; - ptr_Tensor = args.ptr_Tensor; - - batch_stride_A = args.batch_stride_A; - batch_stride_B = args.batch_stride_B; - batch_stride_C = args.batch_stride_C; - batch_stride_D = args.batch_stride_D; - batch_stride_Vector = args.batch_stride_Vector; - batch_stride_Tensor = args.batch_stride_Tensor; - - output_op = args.epilogue; - - semaphore = static_cast(workspace); - - CUTLASS_TRACE_HOST("FusedL2NNWithFusedEpilogue::Params::update()"); - CUTLASS_TRACE_HOST(" ptr_Reduction: " << (void*)this->ptr_Reduction); - CUTLASS_TRACE_HOST(" ptr_Tensor: " << (void*)this->ptr_Tensor); - CUTLASS_TRACE_HOST(" ldr: " << this->ldr); - } - }; - - struct epilogue_SharedStorage { - typename Epilogue::SharedStorage epilogue; - typename Epilogue::TensorTileIterator::SharedStorage reduced_store; - }; - - /// Shared memory storage structure - union SharedStorage { - typename Mma::SharedStorage main_loop; - epilogue_SharedStorage epilogue_combined_store; - }; - - public: - // - // Methods - // - - CUTLASS_DEVICE - FusedL2NNWithFusedEpilogue() {} - - /// Determines whether kernel satisfies alignment - static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) - { - CUTLASS_TRACE_HOST("FusedL2NNWithFusedEpilogue::can_implement()"); - - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - bool isAMisaligned = false; - bool isBMisaligned = false; - bool isCMisaligned = false; - - if (platform::is_same::value) { - isAMisaligned = problem_size.k() % kAlignmentA; - } else if (platform::is_same::value) { - isAMisaligned = problem_size.m() % kAlignmentA; - } else if (platform::is_same>::value || - platform::is_same>::value) { - isAMisaligned = problem_size.k() % kAlignmentA; - } - - if (platform::is_same::value) { - isBMisaligned = problem_size.n() % kAlignmentB; - } else if (platform::is_same::value) { - isBMisaligned = problem_size.k() % kAlignmentB; - } else if (platform::is_same>::value || - platform::is_same>::value) { - isBMisaligned = problem_size.k() % kAlignmentB; - } - - if (platform::is_same::value) { - isCMisaligned = problem_size.n() % kAlignmentC; - } else if (platform::is_same::value) { - isCMisaligned = problem_size.m() % kAlignmentC; - } else if (platform::is_same>::value || - platform::is_same>::value) { - isCMisaligned = problem_size.n() % kAlignmentC; - } - - if (isAMisaligned) { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); - return Status::kErrorMisalignedOperand; - } - - if (isBMisaligned) { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); - return Status::kErrorMisalignedOperand; - } - - if (isCMisaligned) { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); - return Status::kErrorMisalignedOperand; - } - - CUTLASS_TRACE_HOST(" returning kSuccess"); - - return Status::kSuccess; - } - - static Status can_implement(Arguments const& args) { return can_implement(args.problem_size); } - - static size_t get_extra_workspace_size(Arguments const& args, - cutlass::gemm::GemmCoord const& grid_tiled_shape) - { - return 0; - } - -#define SPLIT_K_ENABLED 1 - - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_offset = - threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || - params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { - return; - } - - int offset_k = 0; - int problem_size_k = params.problem_size.k(); - - ElementA* ptr_A = static_cast(params.ptr_A); - ElementB* ptr_B = static_cast(params.ptr_B); - -#if SPLIT_K_ENABLED - // - // Fetch pointers based on mode. - // - if (params.mode == GemmUniversalMode::kGemm || - params.mode == GemmUniversalMode::kGemmSplitKParallel) { - if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; - } - - offset_k = threadblock_tile_offset.k() * params.gemm_k_size; - } else if (params.mode == GemmUniversalMode::kBatched) { - ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; - ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; - } else if (params.mode == GemmUniversalMode::kArray) { - ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; - ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; - } -#endif - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - offset_k, - }; - - cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); - - typename Mma::IteratorB iterator_B( - params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - // - // Masked tile iterators constructed from members - // - - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // assume identity swizzle - MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.n() * Mma::Shape::kN); - - int block_idx = - threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - - ElementC* ptr_C = static_cast(params.ptr_C); - ElementC* ptr_D = static_cast(params.ptr_D); - typename Epilogue::ElementTensor* ptr_Tensor = - static_cast(params.ptr_Tensor); - - // Define the reduction output pointer and move to the appropriate place - typename Epilogue::ElementVector* ptr_Vector = - static_cast(params.ptr_Vector); - - // - // Fetch pointers based on mode. - // - - // - // Special path when split-K not enabled. - // - - if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() == 1) { - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C( - params.params_C, ptr_C, params.problem_size.mn(), thread_idx, threadblock_offset); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D( - params.params_D, ptr_D, params.problem_size.mn(), thread_idx, threadblock_offset); - - // Additional tensor to load from - typename Epilogue::TensorTileIterator tensor_iterator( - shared_storage.epilogue_combined_store.reduced_store, - params.params_Tensor, - // Only the final block outputs Tensor - ptr_Tensor, - params.problem_size.mn(), - thread_idx, - threadblock_offset); - - // Construct the epilogue - Epilogue epilogue( - shared_storage.epilogue_combined_store.epilogue, thread_idx, warp_idx, lane_idx); - - // Move to appropriate location for this output tile - if (ptr_Vector) { - ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; - } - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, - ptr_Vector, - iterator_D, - accumulators, - iterator_C, - tensor_iterator, - params.problem_size.mn(), - threadblock_offset); - - return; - } - - // - // Slower path when split-K or batching is needed - // - -#if SPLIT_K_ENABLED - // Construct the semaphore. - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - - if (params.mode == GemmUniversalMode::kGemm) { - // If performing a reduction via split-K, fetch the initial synchronization - if (params.grid_tiled_shape.k() > 1) { - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); - - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - } - } else if (params.mode == GemmUniversalMode::kGemmSplitKParallel) { - ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; - } else if (params.mode == GemmUniversalMode::kBatched) { - ptr_C += threadblock_tile_offset.k() * params.batch_stride_C; - ptr_D += threadblock_tile_offset.k() * params.batch_stride_D; - if (ptr_Tensor) { ptr_Tensor += threadblock_tile_offset.k() * params.batch_stride_Tensor; } - if (ptr_Vector) { ptr_Vector += threadblock_tile_offset.k() * params.batch_stride_Vector; } - } else if (params.mode == GemmUniversalMode::kArray) { - ptr_C = static_cast(params.ptr_C)[threadblock_tile_offset.k()]; - ptr_D = static_cast(params.ptr_D)[threadblock_tile_offset.k()]; - if (ptr_Tensor) { - ptr_Tensor = static_cast( - params.ptr_Tensor)[threadblock_tile_offset.k()]; - } - if (ptr_Vector) { - ptr_Vector = static_cast( - params.ptr_Vector)[threadblock_tile_offset.k()]; - } - } -#endif - - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C( - params.params_C, ptr_C, params.problem_size.mn(), thread_idx, threadblock_offset); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D( - params.params_D, ptr_D, params.problem_size.mn(), thread_idx, threadblock_offset); - - // Additional tensor to load from - typename Epilogue::TensorTileIterator tensor_iterator( - shared_storage.epilogue_combined_store.reduced_store, - params.params_Tensor, - // Only the final block outputs Tensor - ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && - (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) - ? nullptr - : ptr_Tensor, - params.problem_size.mn(), - thread_idx, - threadblock_offset); - - // Construct the epilogue - Epilogue epilogue( - shared_storage.epilogue_combined_store.epilogue, thread_idx, warp_idx, lane_idx); - -#if SPLIT_K_ENABLED - // Wait on the semaphore - this latency may have been covered by iterator construction - if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - if (threadblock_tile_offset.k()) { iterator_C = iterator_D; } - - semaphore.wait(threadblock_tile_offset.k()); - } -#endif - - // Move to appropriate location for this output tile - if (ptr_Vector) { - ptr_Vector += threadblock_offset.column() + threadblock_tile_offset.m() * params.ldr; - } - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, - // Only the final block uses Vector - ((params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) && - (params.grid_tiled_shape.k() != threadblock_tile_offset.k() + 1)) - ? nullptr - : ptr_Vector, - iterator_D, - accumulators, - iterator_C, - tensor_iterator, - params.problem_size.mn(), - threadblock_offset); - - // - // Release the semaphore - // - -#if SPLIT_K_ENABLED - if ((params.mode == GemmUniversalMode::kGemm) && params.grid_tiled_shape.k() > 1) { - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { - // The final threadblock resets the semaphore for subsequent grids. - lock = 0; - } else { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_offset.k() + 1; - } - - semaphore.release(lock); - } -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh b/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh index 81259cdaea..679aab72a9 100644 --- a/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh +++ b/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh @@ -84,12 +84,21 @@ struct FixConnectivitiesRedOp { DI void init(value_t* out, value_t maxVal) const { *out = maxVal; } DI void init(KVP* out, value_t maxVal) const { - out->key = -1; out->value = maxVal; } DI void init_key(value_t& out, value_idx idx) const { return; } DI void init_key(KVP& out, value_idx idx) const { out.key = idx; } + + DI value_t get_value(KVP& out) const + { + return out.value;; + } + + DI value_t get_value(value_t& out) const + { + return out; + } }; /** From 42b389086b0bef4d97dc83d1116611b18f05b9c1 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 28 Apr 2023 04:24:00 -0700 Subject: [PATCH 27/48] fix formatting issues --- .../custom_epilogue_with_broadcast.h | 465 +++++++++--------- .../detail/fused_distance_nn/cutlass_base.cuh | 64 +-- .../detail/fused_distance_nn/epilogue.cuh | 29 +- .../epilogue_elementwise.cuh | 6 +- .../distance/detail/fused_distance_nn/gemm.h | 54 +- .../fused_distance_nn/persistent_gemm.h | 232 ++++----- .../predicated_tile_iterator_normvec_smem.h | 64 ++- .../predicated_tile_iterator_reduced_vec.h | 129 ++--- .../raft/distance/detail/fused_l2_nn.cuh | 20 +- .../detail/predicated_tile_iterator_normvec.h | 3 +- .../neighbors/detail/connect_components.cuh | 13 +- cpp/include/raft/util/cutlass_utils.cuh | 2 +- cpp/test/CMakeLists.txt | 10 +- 13 files changed, 530 insertions(+), 561 deletions(-) mode change 100755 => 100644 cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h mode change 100755 => 100644 cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h old mode 100755 new mode 100644 index 39fa4d3aea..d90d75a4b4 --- a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h @@ -47,16 +47,16 @@ #include #endif -#include "cutlass/cutlass.h" -#include "cutlass/array.h" -#include "cutlass/numeric_types.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/tensor_coord.h" #include "cutlass/aligned_buffer.h" -#include "cutlass/functional.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" #include "cutlass/fast_math.h" -#include "cutlass/layout/vector.h" +#include "cutlass/functional.h" #include "cutlass/layout/tensor.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" #include "cutlass/gemm/gemm.h" @@ -78,30 +78,27 @@ namespace threadblock { /// This base class is meant to define the concept required of the /// EpilogueWithBroadcast::OutputOp -template < - typename ElementC_, - typename ElementAccumulator_, - typename ElementCompute_, - typename ElementZ_, - typename ElementT_, - int ElementsPerAccess, - bool StoreZ = true, - bool StoreT = true -> +template struct EpilogueWithBroadcastOpBaseCustom { - - using ElementOutput = ElementC_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; - using ElementZ = ElementZ_; - using ElementT = ElementT_; + using ElementOutput = ElementC_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementZ = ElementZ_; + using ElementT = ElementT_; static int const kElementsPerAccess = ElementsPerAccess; using FragmentAccumulator = Array; - using FragmentCompute = Array; - using FragmentC = Array; - using FragmentZ = Array; - using FragmentT = Array; + using FragmentCompute = Array; + using FragmentC = Array; + using FragmentZ = Array; + using FragmentT = Array; /// If true, the 'Z' tensor is stored static bool const kStoreZ = StoreZ; @@ -110,42 +107,38 @@ struct EpilogueWithBroadcastOpBaseCustom { static bool const kStoreT = StoreT; /// Parameters structure - required - struct Params { }; + struct Params {}; // // Methods // /// Constructor from Params - EpilogueWithBroadcastOpBaseCustom(Params const ¶ms_) { } + EpilogueWithBroadcastOpBaseCustom(Params const& params_) {} - /// Determine if the source is needed. May return false if - bool is_source_needed() const { - return true; - } + /// Determine if the source is needed. May return false if + bool is_source_needed() const { return true; } CUTLASS_HOST_DEVICE - void set_k_partition(int k_partition, int k_partition_count) { } + void set_k_partition(int k_partition, int k_partition_count) {} /// Applies the operation when is_source_needed() is true CUTLASS_HOST_DEVICE - void operator()( - FragmentZ &frag_Z, - FragmentT &frag_T, - FragmentAccumulator const &AB, - FragmentC const &frag_C, - FragmentCompute const &V) const { - + void operator()(FragmentZ& frag_Z, + FragmentT& frag_T, + FragmentAccumulator const& AB, + FragmentC const& frag_C, + FragmentCompute const& V) const + { } /// Applies the operation when is_source_needed() is false CUTLASS_HOST_DEVICE - void operator()( - FragmentZ &frag_Z, - FragmentT &frag_T, - FragmentAccumulator const &AB, - FragmentCompute const &V) const { - + void operator()(FragmentZ& frag_Z, + FragmentT& frag_T, + FragmentAccumulator const& AB, + FragmentCompute const& V) const + { } }; @@ -160,62 +153,57 @@ struct EpilogueWithBroadcastOpBaseCustom { /// /// if (ElementwiseOp::kStoreZ) { /// store(converted_u); -/// } +/// } /// /// if (ElementwiseOp::kStoreT) { /// store(v); -/// } +/// } /// template < - typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) - typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) - int PartitionsK, ///< Number of partitions of the K dimension - typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors (z) - typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands (t) - typename ElementVector_, ///< Pointer to broadcast vector - typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators - typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM - typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM - typename OutputOp_, ///< Output operator - concept is EpilogueWithBroadcastOp - typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) - int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity - int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large - (!IsEpilogueFunctorHeavy::value) -> -class EpilogueWithBroadcastCustom : - public EpilogueBase< - Shape_, - typename WarpMmaOperator_::Shape, - PartitionsK, - AccumulatorFragmentIterator_, - WarpTileIterator_, - Padding_, - FragmentsPerPartition> { - -public: - - using Base = EpilogueBase< - Shape_, - typename WarpMmaOperator_::Shape, - PartitionsK, - AccumulatorFragmentIterator_, - WarpTileIterator_, - Padding_, - FragmentsPerPartition>; - - using Shape = Shape_; - using WarpMmaOperator = WarpMmaOperator_; - static int const kPartitionsK = PartitionsK; - using OutputTileIterator = OutputTileIterator_; - using TensorTileIterator = TensorTileIterator_; - using ElementVector = ElementVector_; + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors (z) + typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands (t) + typename ElementVector_, ///< Pointer to broadcast vector + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM + typename OutputOp_, ///< Output operator - concept is EpilogueWithBroadcastOp + typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: + ///< MatrixShape) + int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large + (!IsEpilogueFunctorHeavy::value)> +class EpilogueWithBroadcastCustom : public EpilogueBase { + public: + using Base = EpilogueBase; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using TensorTileIterator = TensorTileIterator_; + using ElementVector = ElementVector_; using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; - using WarpTileIterator = WarpTileIterator_; - using SharedLoadIterator = SharedLoadIterator_; - using OutputOp = OutputOp_; - using Padding = Padding_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; - using Layout = layout::RowMajor; + using Layout = layout::RowMajor; using LongIndex = typename Layout::LongIndex; /// The complete warp-level accumulator tile @@ -234,9 +222,8 @@ class EpilogueWithBroadcastCustom : using ThreadMap = typename OutputTileIterator::ThreadMap; /// Fragment object used to store the broadcast values - using BroadcastFragment = Array< - ElementCompute, - ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess>; + using BroadcastFragment = + Array; /// Output element using ElementOutput = typename OutputTileIterator::Element; @@ -257,40 +244,43 @@ class EpilogueWithBroadcastCustom : using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; /// Array type used to output - using OutputAccessType = Array< - typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; + using OutputAccessType = + Array; /// Array type used by output functor - using AccumulatorAccessType = Array; + using AccumulatorAccessType = + Array; /// Array type used by output functor using ComputeAccessType = Array; /// Tensor access type using TensorAccessType = Array; - + /// Number of warps using WarpCount = typename Base::WarpCount; /// Shared memory allocation from epilogue base class using BaseSharedStorage = typename Base::SharedStorage; - static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; + static int constexpr kSmemTiles = + Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; /// Used for the broadcast struct BroadcastDetail { - /// Number of threads per warp static int const kWarpSize = 32; static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; /// Number of distinct scalar column indices handled by each thread - static int const kColumnsPerThread = ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; + static int const kColumnsPerThread = + ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; /// Number of distinct scalar row indices handled by each thread - static int const kRowsPerThread = ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; + static int const kRowsPerThread = + ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; /// Number of threads per threadblock static int const kThreadCount = kWarpSize * WarpCount::kCount; @@ -298,21 +288,21 @@ class EpilogueWithBroadcastCustom : /// Number of distinct threads per row of output tile static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread); - /// Number of distinct threads which must be reduced during the final reduction phase within the threadblock. + /// Number of distinct threads which must be reduced during the final reduction phase within the + /// threadblock. static int const kThreadRows = kThreadCount / kThreadsPerRow; /// I'm not sure what I meant here. - static int const kThreadAccessesPerRow = const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); + static int const kThreadAccessesPerRow = + const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); - /// Shape of the shared memory allocation for the epilogue - using StorageShape = MatrixShape< - kThreadRows, - Shape::kN - >; + /// Shape of the shared memory allocation for the epilogue + using StorageShape = MatrixShape; /// Debug printing CUTLASS_DEVICE - static void print() { + static void print() + { #if 0 printf("BroadcastDetail {\n"); printf( @@ -340,114 +330,109 @@ class EpilogueWithBroadcastCustom : }; CUTLASS_HOST_DEVICE - SharedStorage() { } + SharedStorage() {} }; -public: - - - // static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, + public: + // static_assert(SharedLoadIterator::Fragment::kElements == + // OutputTileIterator::Fragment::kElements, // "Mismatch between shared load iterator and output tile iterator."); static_assert(SharedLoadIterator::Fragment::kElements == TensorTileIterator::Fragment::kElements, - "Mismatch between shared load iterator and output tile iterator."); - - static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); + "Mismatch between shared load iterator and output tile iterator."); - static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), - "Divisibility"); + static_assert(OutputTileIterator::kElementsPerAccess, + "OutputTileIterator::kElementsPerAccess must not be zero."); -private: + static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), + "Divisibility"); + private: /// Loads fragment from shared memory aligned with output tensor SharedLoadIterator shared_load_iterator_; /// Thread index within the threadblock int thread_idx_; -public: - + public: /// Constructor CUTLASS_DEVICE - EpilogueWithBroadcastCustom( - SharedStorage &shared_storage, ///< Shared storage object - int thread_idx, ///< ID of a thread within the threadblock - int warp_idx, ///< ID of warp within threadblock - int lane_idx ///< Id of thread within warp - ): - Base(shared_storage.base, thread_idx, warp_idx, lane_idx), - shared_load_iterator_(shared_storage.base.reference(), thread_idx), - thread_idx_(thread_idx) + EpilogueWithBroadcastCustom(SharedStorage& shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ) + : Base(shared_storage.base, thread_idx, warp_idx, lane_idx), + shared_load_iterator_(shared_storage.base.reference(), thread_idx), + thread_idx_(thread_idx) { - } /// Streams the result to global memory CUTLASS_DEVICE void operator()( - OutputOp const &output_op, ///< Output operator - ElementVector const * broadcast_ptr, ///< Broadcast vector - AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile - OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix - TensorTileIterator tensor_iterator, ///< Threadblock tile iterator for additional tensor operand - MatrixCoord const &problem_size = ///< Problem size needed to guard against out-of-bounds accesses - MatrixCoord(Shape::kM, Shape::kN), - MatrixCoord const &threadblock_offset = ///< Threadblock's initial offset within the problem size space - MatrixCoord()) { - + OutputOp const& output_op, ///< Output operator + ElementVector const* broadcast_ptr, ///< Broadcast vector + AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix + TensorTileIterator + tensor_iterator, ///< Threadblock tile iterator for additional tensor operand + MatrixCoord const& + problem_size = ///< Problem size needed to guard against out-of-bounds accesses + MatrixCoord(Shape::kM, Shape::kN), + MatrixCoord const& + threadblock_offset = ///< Threadblock's initial offset within the problem size space + MatrixCoord()) + { BroadcastFragment broadcast_fragment; load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset); compute_source_needed_( - output_op, - broadcast_fragment, - accumulators, - source_iterator, - tensor_iterator); + output_op, broadcast_fragment, accumulators, source_iterator, tensor_iterator); } -private: - + private: CUTLASS_DEVICE void load_broadcast_fragment_( - BroadcastFragment & broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns - ElementVector const * broadcast_ptr, ///< Broadcast vector - MatrixCoord const &problem_size, ///< Problem size needed to guard against out-of-bounds accesses - MatrixCoord const &threadblock_offset ///< Threadblock's initial offset within the problem size space - ) { - + BroadcastFragment& + broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns + ElementVector const* broadcast_ptr, ///< Broadcast vector + MatrixCoord const& + problem_size, ///< Problem size needed to guard against out-of-bounds accesses + MatrixCoord const& + threadblock_offset ///< Threadblock's initial offset within the problem size space + ) + { broadcast_fragment.clear(); - + // If no pointer is supplied, set with all zeros and avoid memory accesses - if (!broadcast_ptr) { - return; - } + if (!broadcast_ptr) { return; } int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column(); int thread_column_idx = threadblock_offset.column() + thread_initial_column; broadcast_ptr += thread_initial_column; - NumericArrayConverter converter; - using AccessType = AlignedArray; + NumericArrayConverter + converter; + using AccessType = AlignedArray; using ComputeFragmentType = Array; - ComputeFragmentType *frag_ptr = reinterpret_cast(&broadcast_fragment); + ComputeFragmentType* frag_ptr = reinterpret_cast(&broadcast_fragment); CUTLASS_PRAGMA_UNROLL for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) { - AccessType loaded; loaded.clear(); if (thread_column_idx < problem_size.column()) { - loaded = *reinterpret_cast(broadcast_ptr); + loaded = *reinterpret_cast(broadcast_ptr); } ComputeFragmentType cvt = converter(loaded); - frag_ptr[j] = cvt; + frag_ptr[j] = cvt; thread_column_idx += ThreadMap::Delta::kColumn; broadcast_ptr += ThreadMap::Delta::kColumn; @@ -461,7 +446,8 @@ class EpilogueWithBroadcastCustom : struct acc2smem_source_not_needed> { template CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, - WarpTileIterator &warp_tile_iterator) { + WarpTileIterator& warp_tile_iterator) + { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < Advance; i++) { ++accum_fragment_iterator; @@ -488,11 +474,12 @@ class EpilogueWithBroadcastCustom : CUTLASS_DEVICE static void push(size_t pos, - AccumulatorFragmentIterator const &iterator_begin, - WarpTileIterator &warp_tile_iterator) { + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) + { int dummy[] = { - (pos == (Seq * Base::kFragmentsPerIteration)) && - (helper(iterator_begin, warp_tile_iterator), 0)...}; + (pos == (Seq * Base::kFragmentsPerIteration)) && + (helper(iterator_begin, warp_tile_iterator), 0)...}; CUTLASS_UNUSED(dummy[0]); } @@ -501,12 +488,14 @@ class EpilogueWithBroadcastCustom : /// Streams the result to global memory CUTLASS_DEVICE void compute_source_not_needed_( - OutputOp const &output_op, ///< Output operator - BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns - OutputTileIterator destination_iterator, ///< Tile iterator for destination - AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile - TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand - ) { + OutputOp const& output_op, ///< Output operator + BroadcastFragment const& + broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile + TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand + ) + { #if 0 // // Iterator over warp-level accumulator fragment @@ -519,7 +508,8 @@ class EpilogueWithBroadcastCustom : // // CUTLASS_PRAGMA_UNROLL - #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1) +#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration \ + : 1) for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) { // @@ -602,16 +592,15 @@ class EpilogueWithBroadcastCustom : #endif } - - template + template struct acc2smem_source_needed; template struct acc2smem_source_needed> { - template - CUTLASS_DEVICE - static void helper(AccumulatorFragmentIterator accum_fragment_iterator, - WarpTileIterator &warp_tile_iterator) { + template + CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) + { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < Advance; i++) { ++accum_fragment_iterator; @@ -624,23 +613,25 @@ class EpilogueWithBroadcastCustom : CUTLASS_DEVICE static void push(size_t pos, - AccumulatorFragmentIterator const &iterator_begin, - WarpTileIterator &warp_tile_iterator) { + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) + { int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; } }; - /// Streams the result to global memory CUTLASS_DEVICE void compute_source_needed_( - OutputOp const &output_op, ///< Output operator - BroadcastFragment const &broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns - AccumulatorTile const &accumulators, ///< Complete warp-level accumulator tile - OutputTileIterator source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) - TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand - ) { - + OutputOp const& output_op, ///< Output operator + BroadcastFragment const& + broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns + AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator + source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand + ) + { typename OutputTileIterator::Fragment source_fragment; source_fragment.clear(); @@ -652,19 +643,18 @@ class EpilogueWithBroadcastCustom : // // Iterate over accumulator tile - // + // - #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) +#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { - // // Convert and store fragment // - + //__syncthreads(); acc2smem_source_needed>::push( - iter, accum_fragment_iterator, this->warp_tile_iterator_); + iter, accum_fragment_iterator, this->warp_tile_iterator_); __syncthreads(); @@ -706,12 +696,7 @@ class EpilogueWithBroadcastCustom : ++source_iterator; apply_output_operator_( - frag_T, - output_op, - aligned_accum_fragment[0], - source_fragment, - broadcast_fragment); - + frag_T, output_op, aligned_accum_fragment[0], source_fragment, broadcast_fragment); // // Conditionally store fragments @@ -725,57 +710,53 @@ class EpilogueWithBroadcastCustom : /// Helper to invoke the output functor over each vector of output CUTLASS_DEVICE - void apply_output_operator_( - typename TensorTileIterator::Fragment &frag_T, - OutputOp const &output_op, - typename SharedLoadIterator::Fragment const &frag_AB, - typename OutputTileIterator::Fragment const &frag_C, - BroadcastFragment const &frag_Broadcast) { - - using AccessTypeT = Array; + void apply_output_operator_(typename TensorTileIterator::Fragment& frag_T, + OutputOp const& output_op, + typename SharedLoadIterator::Fragment const& frag_AB, + typename OutputTileIterator::Fragment const& frag_C, + BroadcastFragment const& frag_Broadcast) + { + using AccessTypeT = Array; using AccessTypeBroadcast = Array; - AccessTypeT *frag_T_ptr = reinterpret_cast(&frag_T); - - AccumulatorAccessType const *frag_AB_ptr = - reinterpret_cast(&frag_AB); + AccessTypeT* frag_T_ptr = reinterpret_cast(&frag_T); + + AccumulatorAccessType const* frag_AB_ptr = + reinterpret_cast(&frag_AB); - OutputAccessType const *frag_C_ptr = - reinterpret_cast(&frag_C); + OutputAccessType const* frag_C_ptr = reinterpret_cast(&frag_C); - AccessTypeBroadcast const *frag_Broadcast_ptr = - reinterpret_cast(&frag_Broadcast); + AccessTypeBroadcast const* frag_Broadcast_ptr = + reinterpret_cast(&frag_Broadcast); - int const kOutputOpIterations = + int const kOutputOpIterations = TensorTileIterator::Fragment::kElements / TensorTileIterator::kElementsPerAccess; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kOutputOpIterations; ++i) { - - output_op( - frag_T_ptr[i], - frag_AB_ptr[i], - frag_C_ptr[(i / ThreadMap::Iterations::kColumn)], - frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); + output_op(frag_T_ptr[i], + frag_AB_ptr[i], + frag_C_ptr[(i / ThreadMap::Iterations::kColumn)], + frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); } } /// Helper to invoke the output functor over each vector of output CUTLASS_DEVICE void apply_output_operator_source_not_needed_( - typename OutputTileIterator::Fragment &frag_Z, - typename TensorTileIterator::Fragment &frag_T, - OutputOp const &output_op, - typename SharedLoadIterator::Fragment const &frag_AB, - BroadcastFragment const &frag_Broadcast) { - + typename OutputTileIterator::Fragment& frag_Z, + typename TensorTileIterator::Fragment& frag_T, + OutputOp const& output_op, + typename SharedLoadIterator::Fragment const& frag_AB, + BroadcastFragment const& frag_Broadcast) + { } }; //////////////////////////////////////////////////////////////////////////////// -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass //////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh index 8bcf006528..2b742f9ee6 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh @@ -25,11 +25,11 @@ #define cutlass raft_cutlass #endif -#include #include #include -#include #include +#include +#include #include #include @@ -38,8 +38,8 @@ #include #include -#include #include +#include namespace raft { namespace distance { @@ -72,21 +72,22 @@ void cutlassFusedDistanceNN(const DataT* x, KVPReduceOpT pairRedOp, cudaStream_t stream) { - using EpilogueOutputOp = - cutlass::epilogue::thread::FusedDistanceNNEpilogueElementwise::value, - 1, // Elements per access 1 - DistanceFn, - CGReduceOpT, - ReduceOpT, - KVPReduceOpT>; + using EpilogueOutputOp = cutlass::epilogue::thread::FusedDistanceNNEpilogueElementwise< + DataT, // ElementC_ + AccT, // ElementAccumulator_ + DataT, // ElementCompute_ + AccT, // ElementZ_ + OutT, // ElementT_ + // 128 / cutlass::sizeof_bits::value, + 1, // Elements per access 1 + DistanceFn, + CGReduceOpT, + ReduceOpT, + KVPReduceOpT>; constexpr int batch_count = 1; - typename EpilogueOutputOp::Params epilog_op_param(dist_op, cg_reduce_op, redOp, pairRedOp, mutexes); + typename EpilogueOutputOp::Params epilog_op_param( + dist_op, cg_reduce_op, redOp, pairRedOp, mutexes); // Number of pipelines you want to use constexpr int NumStages = 3; @@ -100,34 +101,33 @@ void cutlassFusedDistanceNN(const DataT* x, using fusedDistanceNNKernel = typename cutlass::gemm::kernel::FusedDistanceNNGemm::GemmKernel; - + Alignment, + DataT, + Alignment, + AccT, + AccT, + EpilogueOutputOp, + NumStages, // Number of pipeline stages + isRowMajor>::GemmKernel; using fusedDistanceNN = cutlass::gemm::device::GemmGrouped; - int num_blocks_per_sm = fusedDistanceNN::maximum_active_blocks(); - int num_sms = raft::getMultiProcessorCount(); - int num_blocks = num_blocks_per_sm * num_sms; + int num_blocks_per_sm = fusedDistanceNN::maximum_active_blocks(); + int num_sms = raft::getMultiProcessorCount(); + int num_blocks = num_blocks_per_sm * num_sms; constexpr int mmaShapeM = fusedDistanceNNKernel::Mma::Shape::kM; auto thread_blocks = std::max(num_blocks, int((problem_size.m() - 1 + mmaShapeM) / mmaShapeM)); typename fusedDistanceNN::Arguments arguments{ problem_size, - batch_count, // num of problems. + batch_count, // num of problems. thread_blocks, epilog_op_param, x, y, - xn, // C matrix eq vector param, which here is A norm - (DataT*)yn, // this is broadcast vec, which is required to be non-const param - dOutput, // Output distance matrix + xn, // C matrix eq vector param, which here is A norm + (DataT*)yn, // this is broadcast vec, which is required to be non-const param + dOutput, // Output distance matrix (int64_t)lda, // stride A (int64_t)ldb, // stride B (int64_t)1, // stride A norm diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh index 96c6697c02..7feaea1f02 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh @@ -71,8 +71,8 @@ struct FusedDistanceNNEpilogue { // // Stores the result z = (y = GEMM(A, B, C), broadcast) // - using RowNormTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorNormVecSmem< - typename Base::OutputTileThreadMap, ElementOutput, LayoutT>; + using RowNormTileIterator = cutlass::epilogue::threadblock:: + PredicatedTileIteratorNormVecSmem; // // Additional tensor tile iterator - stores t = Elementwise(z) @@ -84,18 +84,19 @@ struct FusedDistanceNNEpilogue { typename OutputOp::Params>; /// Define the epilogue - using Epilogue = cutlass::epilogue::threadblock::EpilogueWithBroadcastCustom; + using Epilogue = cutlass::epilogue::threadblock::EpilogueWithBroadcastCustom< + Shape, + WarpMmaTensorOp, + PartitionsK, + RowNormTileIterator, + OutputTileIterator, + ElementVector, + typename Base::AccumulatorFragmentIterator, + typename Base::WarpTileIterator, + typename Base::SharedLoadIterator, + OutputOp, + typename Base::Padding, + Base::kFragmentsPerIteration>; }; } // namespace threadblock diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh index 00fafe5fa1..4058406bc1 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh @@ -73,8 +73,8 @@ class FusedDistanceNNEpilogueElementwise { using FragmentC = Array; using FragmentZ = Array; using OutValT = typename CGReduceOp::AccTypeT; - //using FragmentT = Array; - using FragmentT = Array; + // using FragmentT = Array; + using FragmentT = Array; using FragmentOutput = FragmentZ; @@ -164,7 +164,7 @@ class FusedDistanceNNEpilogueElementwise { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kElementsPerAccess; ++i) { ElementCompute res_Z = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); - //red_op.init(&frag_T[i], res_Z); + // red_op.init(&frag_T[i], res_Z); frag_T[i] = res_Z; } } diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h index 7b9858d9bc..feb078fcc0 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h @@ -22,9 +22,9 @@ #include #include -//#include -#include +// #include #include +#include ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -60,15 +60,18 @@ struct FusedDistanceNNGemm { /// Threadblock-level tile size (concept: GemmShape) // <- threadblock tile M = 32, N = 64, K = 16 - using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 16>; // this is more performant for grouped GEMM - //using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; // this shape has high occupancy but less perf + using ThreadblockShape = + cutlass::gemm::GemmShape<32, 256, 16>; // this is more performant for grouped GEMM + // using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; // this shape has high + // occupancy but less perf /// Warp-level tile size (concept: GemmShape) // This code section describes tile size a warp will compute // <- warp tile M = 64, N = 64, K = 16 - using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; // this is more performant for grouped GEMM - //using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; // this shape has high occupancy but less perf - + using WarpShape = + cutlass::gemm::GemmShape<32, 64, 16>; // this is more performant for grouped GEMM + // using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; // this shape has high occupancy but + // less perf /// Warp-level tile size (concept: GemmShape) // This code section describes the size of MMA op @@ -77,7 +80,7 @@ struct FusedDistanceNNGemm { /// Operation performed by GEMM using Operator = cutlass::arch::OpMultiplyAddFastF32; - //using Operator = cutlass::arch::OpMultiplyAdd; // this runs only 1xTF32 for float inputs + // using Operator = cutlass::arch::OpMultiplyAdd; // this runs only 1xTF32 for float inputs // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU // SM @@ -137,10 +140,11 @@ struct FusedDistanceNNGemm { NormXLayout, GemmBase::Epilogue::kElementsPerAccess>::Epilogue; - // Compose the GEMM kernel - using GemmKernel = FusedDistanceNNPersistent; + using GemmKernel = FusedDistanceNNPersistent; }; template < @@ -159,23 +163,23 @@ template < /// data layout row/column major of inputs bool isRowMajor> struct FusedDistanceNNGemm { + kAlignmentA, + double, + kAlignmentB, + ElementC_, + ElementAccumulator, + EpilogueOutputOp, + Stages, + isRowMajor> { // Threadblock-level tile size (concept: GemmShape) // <- threadblock tile M = 64, N = 64, K = 16 using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; - //using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; + // using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; /// Warp-level tile size (concept: GemmShape) // This code section describes tile size a warp will compute // <- warp tile M = 32, N = 32, K = 16 using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; - //using WarpShape = cutlass::gemm::GemmShape<16, 32, 16>; + // using WarpShape = cutlass::gemm::GemmShape<16, 32, 16>; /// Warp-level tile size (concept: GemmShape) // This code section describes the size of MMA op using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; @@ -193,7 +197,6 @@ struct FusedDistanceNNGemm; - /// data layout for final output matrix. // we keep this same layout even for column major inputs using LayoutOutput = cutlass::layout::RowMajor; @@ -242,9 +245,10 @@ struct FusedDistanceNNGemm::Epilogue; // Compose the GEMM kernel - using GemmKernel = FusedDistanceNNPersistent; - + using GemmKernel = FusedDistanceNNPersistent; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h index b08cc0a9ab..cca25c0cdd 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h @@ -122,8 +122,8 @@ struct FusedDistanceNNPersistent { struct temp_problem_visitor { int problem_count; - - CUTLASS_HOST_DEVICE temp_problem_visitor() : problem_count(0) {}; + + CUTLASS_HOST_DEVICE temp_problem_visitor() : problem_count(0){}; CUTLASS_HOST_DEVICE temp_problem_visitor(int problem_count_) : problem_count(problem_count_){}; }; @@ -160,7 +160,7 @@ struct FusedDistanceNNPersistent { /// Default ctor CUTLASS_HOST_DEVICE Arguments() - : //problem_count(0), + : // problem_count(0), threadblock_count(0), ptr_A(nullptr), ptr_B(nullptr), @@ -207,8 +207,6 @@ struct FusedDistanceNNPersistent { { problem_visitor.problem_count = problem_count; } - - }; // @@ -217,7 +215,7 @@ struct FusedDistanceNNPersistent { /// Parameters structure struct Params { - //typename ProblemVisitor::Params problem_visitor; + // typename ProblemVisitor::Params problem_visitor; temp_problem_visitor problem_visitor; int threadblock_count; @@ -257,7 +255,7 @@ struct FusedDistanceNNPersistent { lda(0), ldb(0), ldc(0), - ldt(0) + ldt(0) { } @@ -290,9 +288,9 @@ struct FusedDistanceNNPersistent { { threadblock_count = args.threadblock_count; output_op = args.output_op; - ptr_A = const_cast(args.ptr_A); - ptr_B = const_cast(args.ptr_B); - ptr_C = const_cast(args.ptr_C); + ptr_A = const_cast(args.ptr_A); + ptr_B = const_cast(args.ptr_B); + ptr_C = const_cast(args.ptr_C); ptr_Vector = args.ptr_Vector; ptr_Tensor = args.ptr_Tensor; lda = args.lda; @@ -304,7 +302,6 @@ struct FusedDistanceNNPersistent { } }; - /// Shared memory storage structure struct SharedStorage { union { @@ -339,20 +336,19 @@ struct FusedDistanceNNPersistent { } CUTLASS_DEVICE - static uint32_t tile_count(const cutlass::MatrixCoord& grid) { + static uint32_t tile_count(const cutlass::MatrixCoord& grid) + { return grid.row() * grid.column(); } - /// Get the grid shape + /// Get the grid shape CUTLASS_DEVICE - static cutlass::MatrixCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { - - return cutlass::MatrixCoord( - ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), - ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN)); + static cutlass::MatrixCoord grid_shape(const cutlass::gemm::GemmCoord& problem) + { + return cutlass::MatrixCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN)); } - /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const& params, SharedStorage& shared_storage) @@ -369,16 +365,18 @@ struct FusedDistanceNNPersistent { using LayoutC = typename Epilogue::OutputTileIterator::Layout; using ElementOut = typename Epilogue::TensorTileIterator::Element; using LongIndexOut = typename Epilogue::TensorTileIterator::LongIndex; - using OutValTy = typename Epilogue::TensorTileIterator::OutValT; + using OutValTy = typename Epilogue::TensorTileIterator::OutValT; - const GemmCoord& problem_size = params.problem_size; - const uint32_t problem_chunk = (tile_count(grid_shape(problem_size)) - 1 + gridDim.x) / gridDim.x; + const GemmCoord& problem_size = params.problem_size; + const uint32_t problem_chunk = + (tile_count(grid_shape(problem_size)) - 1 + gridDim.x) / gridDim.x; const uint32_t problem_chunk_end = blockIdx.x * problem_chunk + problem_chunk; - const auto grid_shape_ = grid_shape(problem_size); - typename LayoutB::Index column = ((blockIdx.x * problem_chunk) % grid_shape_.column()) * Mma::Shape::kN; + const auto grid_shape_ = grid_shape(problem_size); + typename LayoutB::Index column = + ((blockIdx.x * problem_chunk) % grid_shape_.column()) * Mma::Shape::kN; { ElementOut* shared_elem_arr_ = shared_storage.reduced_store.data(); - constexpr auto maxVal_ = std::numeric_limits::max(); + constexpr auto maxVal_ = std::numeric_limits::max(); if (column) { for (int row = threadIdx.x; row < Mma::Shape::kM; row += blockDim.x) { @@ -388,124 +386,126 @@ struct FusedDistanceNNPersistent { } { - ElementC* shared_elem_arr = shared_storage.rownorm_store.data(); - if (column) { - typename LayoutB::Index row = ((blockIdx.x * problem_chunk) / grid_shape_.column()) * Mma::Shape::kM; - - uint8_t* first_tile_byte_pointer_ = reinterpret_cast(params.ptr_C) + - typename LayoutB::LongIndex(row) * typename LayoutB::LongIndex(sizeof(ElementC)); - auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - - for (int row_local = threadIdx.x ; row_local < Mma::Shape::kM; row_local += blockDim.x) { - bool guard = (row + row_local) < problem_size.m(); - cutlass::arch::cp_async(shared_elem_arr + row_local, gmem_ptr + row_local, guard); - cutlass::arch::cp_async_wait<0>(); - } + ElementC* shared_elem_arr = shared_storage.rownorm_store.data(); + if (column) { + typename LayoutB::Index row = + ((blockIdx.x * problem_chunk) / grid_shape_.column()) * Mma::Shape::kM; + + uint8_t* first_tile_byte_pointer_ = + reinterpret_cast(params.ptr_C) + + typename LayoutB::LongIndex(row) * typename LayoutB::LongIndex(sizeof(ElementC)); + auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + + for (int row_local = threadIdx.x; row_local < Mma::Shape::kM; row_local += blockDim.x) { + bool guard = (row + row_local) < problem_size.m(); + cutlass::arch::cp_async( + shared_elem_arr + row_local, gmem_ptr + row_local, guard); + cutlass::arch::cp_async_wait<0>(); } + } } // Outer 'persistent' loop to iterate over tiles for (uint32_t tile_idx = blockIdx.x * problem_chunk; tile_idx < problem_chunk_end; tile_idx++) { + const auto grid_shape_ = grid_shape(problem_size); + cutlass::MatrixCoord threadblock_offset( + int(tile_idx / grid_shape_.column()) * Mma::Shape::kM, + int(tile_idx % grid_shape_.column()) * Mma::Shape::kN); - const auto grid_shape_ = grid_shape(problem_size); - cutlass::MatrixCoord threadblock_offset( - int(tile_idx / grid_shape_.column()) * Mma::Shape::kM, - int(tile_idx % grid_shape_.column()) * Mma::Shape::kN); + const bool isNextTile = ((tile_idx + 1) < problem_chunk_end); + const bool doesRowChange = + ((threadblock_offset.column() + Mma::Shape::kN) >= problem_size.n()); + const bool do_gmem_reduce = (doesRowChange || !isNextTile) ? true : false; - const bool isNextTile = ((tile_idx + 1) < problem_chunk_end); - const bool doesRowChange = ((threadblock_offset.column() + Mma::Shape::kN) >= problem_size.n()); - const bool do_gmem_reduce = (doesRowChange || !isNextTile) ? true : false; + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); - ElementA* ptr_A = static_cast(params.ptr_A); - ElementB* ptr_B = static_cast(params.ptr_B); + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{threadblock_offset.row(), 0}; + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.column()}; - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{threadblock_offset.row(), 0}; - cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.column()}; + // Compute position within threadblock + int thread_idx = threadIdx.x; - // Compute position within threadblock - int thread_idx = threadIdx.x; + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A); - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.params_A, ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A); + typename Mma::IteratorB iterator_B( + params.params_B, ptr_B, {problem_size.k(), problem_size.n()}, thread_idx, tb_offset_B); - typename Mma::IteratorB iterator_B( - params.params_B, ptr_B, {problem_size.k(), problem_size.n()}, thread_idx, tb_offset_B); + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; - int lane_idx = threadIdx.x % 32; + // + // Matrix multiply phase + // - // - // Matrix multiply phase - // + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); + typename Mma::FragmentC accumulators; - typename Mma::FragmentC accumulators; + accumulators.clear(); + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; - accumulators.clear(); - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + // Wait for all threads to finish their epilogue phases from the previous tile. + //__syncthreads(); - // Wait for all threads to finish their epilogue phases from the previous tile. - //__syncthreads(); + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + // + // Epilogue + // - // - // Epilogue - // + EpilogueOutputOp output_op(params.output_op); - EpilogueOutputOp output_op(params.output_op); - - ElementC* ptr_C = static_cast(params.ptr_C); - typename Epilogue::ElementTensor* ptr_Tensor = + ElementC* ptr_C = static_cast(params.ptr_C); + typename Epilogue::ElementTensor* ptr_Tensor = static_cast(params.ptr_Tensor); // Define the reduction output pointer and move to the appropriate place - typename Epilogue::ElementVector* ptr_Vector = + typename Epilogue::ElementVector* ptr_Vector = static_cast(params.ptr_Vector); - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_rownorm( - shared_storage.rownorm_store, - params.params_C, ptr_C, problem_size.mn(), thread_idx, - threadblock_offset); - - // Additional tensor to load from - typename Epilogue::TensorTileIterator tensor_iterator( - shared_storage.reduced_store, - params.params_Tensor, - // Only the final block outputs Tensor - ptr_Tensor, - problem_size.mn(), - thread_idx, - do_gmem_reduce, - threadblock_offset); - - Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); - - // Execute the epilogue operator to update the destination tensor. - // Move to appropriate location for this output tile - if (ptr_Vector) { - ptr_Vector += threadblock_offset.column(); - } - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, - ptr_Vector, - //iterator_D, - accumulators, - iterator_rownorm, - tensor_iterator, - problem_size.mn(), - threadblock_offset); + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_rownorm(shared_storage.rownorm_store, + params.params_C, + ptr_C, + problem_size.mn(), + thread_idx, + threadblock_offset); + + // Additional tensor to load from + typename Epilogue::TensorTileIterator tensor_iterator(shared_storage.reduced_store, + params.params_Tensor, + // Only the final block outputs Tensor + ptr_Tensor, + problem_size.mn(), + thread_idx, + do_gmem_reduce, + threadblock_offset); + + Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + // Move to appropriate location for this output tile + if (ptr_Vector) { ptr_Vector += threadblock_offset.column(); } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, + ptr_Vector, + // iterator_D, + accumulators, + iterator_rownorm, + tensor_iterator, + problem_size.mn(), + threadblock_offset); } } }; diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h old mode 100755 new mode 100644 index 20e44521b9..86d5e688a8 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -58,8 +58,8 @@ namespace threadblock { /// /// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator /// -template @@ -95,9 +95,8 @@ class PredicatedTileIteratorNormVecSmem { static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); using Fragment = Array; + ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; /// Memory access size using AccessType = AlignedArray; @@ -191,13 +190,13 @@ class PredicatedTileIteratorNormVecSmem { /// Byte-level pointer uint8_t* byte_pointer_; - //uint8_t* first_tile_byte_pointer_; + // uint8_t* first_tile_byte_pointer_; /// Array of boolean values to contain steady-state predicates Mask mask_; /// Extent of the matrix tile in rows Index extent_row_; - //Index block_start_row_first_tile_; + // Index block_start_row_first_tile_; /// Extent of the matrix tile in rows Index extent_column_; @@ -214,7 +213,6 @@ class PredicatedTileIteratorNormVecSmem { /// Scatter indices int const* indices_; - // // Static asserts about internal strides // @@ -239,16 +237,15 @@ class PredicatedTileIteratorNormVecSmem { /// Constructor CUTLASS_DEVICE PredicatedTileIteratorNormVecSmem(SharedStorage& shared_storage, - PredicatedTileIteratorParams const& params, - Element* pointer, - TensorCoord extent, - int thread_idx, - //const bool init_shmem, - TensorCoord& threadblock_offset, - int const* indices = nullptr) + PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + // const bool init_shmem, + TensorCoord& threadblock_offset, + int const* indices = nullptr) : params_(params), indices_(indices), shared_storage_(shared_storage) { - TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; extent_row_ = extent.row(); @@ -257,7 +254,6 @@ class PredicatedTileIteratorNormVecSmem { thread_start_row_ = thread_offset.row(); thread_start_column_ = thread_offset.column(); - // Initialize predicates CUTLASS_PRAGMA_UNROLL for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { @@ -266,9 +262,9 @@ class PredicatedTileIteratorNormVecSmem { } // Null pointer performs no accesses - if (!pointer) { - mask_.clear(); - return; + if (!pointer) { + mask_.clear(); + return; } if (ScatterD && !indices) { mask_.clear(); } @@ -283,16 +279,17 @@ class PredicatedTileIteratorNormVecSmem { } if (threadblock_offset.column() == 0) { - Element* shared_elem_arr = shared_storage_.data(); - uint8_t* first_tile_byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(threadblock_offset.row()) * LongIndex(params_.stride); - auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - bool guard = (threadblock_offset.row() + row) < extent_row_; - cutlass::arch::cp_async(shared_elem_arr + row, gmem_ptr + row, guard); - cutlass::arch::cp_async_wait<0>(); - } + Element* shared_elem_arr = shared_storage_.data(); + uint8_t* first_tile_byte_pointer_ = + reinterpret_cast(pointer) + + LongIndex(threadblock_offset.row()) * LongIndex(params_.stride); + auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + + for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { + bool guard = (threadblock_offset.row() + row) < extent_row_; + cutlass::arch::cp_async(shared_elem_arr + row, gmem_ptr + row, guard); + cutlass::arch::cp_async_wait<0>(); + } } // Initialize internal state counter @@ -310,7 +307,7 @@ class PredicatedTileIteratorNormVecSmem { CUTLASS_DEVICE void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const { - AccessType* frag_ptr = reinterpret_cast(&frag); + AccessType* frag_ptr = reinterpret_cast(&frag); Element* shared_elem_arr = shared_storage_.data(); @@ -326,7 +323,7 @@ class PredicatedTileIteratorNormVecSmem { int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; int iter_row = ((row_offset + thread_start_row_) % total_rows); - Element val = shared_elem_arr[iter_row]; + Element val = shared_elem_arr[iter_row]; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kElementsPerAccess; ++i) { @@ -341,7 +338,6 @@ class PredicatedTileIteratorNormVecSmem { CUTLASS_DEVICE void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } - CUTLASS_DEVICE MatrixCoord thread_start() const { return MatrixCoord(thread_start_row_, thread_start_column_); } diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index 6ba884f318..c3cb3a458c 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -89,9 +89,9 @@ class PredicatedTileIteratorReducedVec { using OutValT = typename EpilogueOpParams::CGReduceT::AccTypeT; static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - //static int const kElementsPerAccess = 1; - static int const kThreads = ThreadMap::kThreads; - static int const kIterations = ThreadMap::Count::kTile; + // static int const kElementsPerAccess = 1; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); @@ -99,7 +99,6 @@ class PredicatedTileIteratorReducedVec { static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); static_assert(!UseCUDAStore, "UseCUDAStore path is not supported"); - static int const total_rows = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * ThreadMap::Count::kTile * ThreadMap::Delta::kRow; @@ -110,7 +109,7 @@ class PredicatedTileIteratorReducedVec { ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * kElementsPerAccess>; // Memory access size - using AccessType = AlignedArray; + using AccessType = AlignedArray; using AccessTypeValT = AlignedArray; // @@ -147,7 +146,7 @@ class PredicatedTileIteratorReducedVec { /// Mask object struct Mask { - //static int const kCount = ThreadMap::Iterations::kColumn; + // static int const kCount = ThreadMap::Iterations::kColumn; static int const kCount = ThreadMap::Iterations::kColumn * kElementsPerAccess; /// Predicate state @@ -192,7 +191,6 @@ class PredicatedTileIteratorReducedVec { // // Data members - // // Methods // @@ -204,19 +202,23 @@ class PredicatedTileIteratorReducedVec { SharedStorage() {} }; - template + template struct select_reduce { /// Performs reduction and stores a reduced output to memory CUTLASS_DEVICE - select_reduce(OutT value, ValT prev_red_val, cg_reduce_op_t reduce_op, - cg_group_t cg_warp_group, OutT& shmem_ptr) + select_reduce(OutT value, + ValT prev_red_val, + cg_reduce_op_t reduce_op, + cg_group_t cg_warp_group, + OutT& shmem_ptr) { if (cg_warp_group.any(reduce_op.isAmin(value, prev_red_val))) { OutT reduced_val = cg::reduce(cg_warp_group, value, reduce_op); - if (cg_warp_group.thread_rank() == 0) { - shmem_ptr = reduced_val; - } + if (cg_warp_group.thread_rank() == 0) { shmem_ptr = reduced_val; } } } }; @@ -227,19 +229,20 @@ class PredicatedTileIteratorReducedVec { using Ty = raft::KeyValuePair; CUTLASS_DEVICE - select_reduce(Ty val_to_red, float prev_red_val, cg_reduce_op_t cg_reduce_op, - cg_group_t cg_warp_group, Ty& shmem_ptr) + select_reduce(Ty val_to_red, + float prev_red_val, + cg_reduce_op_t cg_reduce_op, + cg_group_t cg_warp_group, + Ty& shmem_ptr) { - ValT val = val_to_red.value; + ValT val = val_to_red.value; if (cg_warp_group.any(cg_reduce_op.isAmin(val, prev_red_val))) { ValT reduced_val = cg::reduce(cg_warp_group, val, cg_reduce_op); bool pred = (reduced_val == val); auto subTile = cg::binary_partition(cg_warp_group, pred); if (pred) { - if (subTile.thread_rank() == 0) { - shmem_ptr = val_to_red; - } + if (subTile.thread_rank() == 0) { shmem_ptr = val_to_red; } } } } @@ -251,19 +254,20 @@ class PredicatedTileIteratorReducedVec { using Ty = raft::KeyValuePair; CUTLASS_DEVICE - select_reduce(Ty val_to_red, double prev_red_val, cg_reduce_op_t cg_reduce_op, - cg_group_t cg_warp_group, Ty& shmem_ptr) + select_reduce(Ty val_to_red, + double prev_red_val, + cg_reduce_op_t cg_reduce_op, + cg_group_t cg_warp_group, + Ty& shmem_ptr) { - ValT val = val_to_red.value; + ValT val = val_to_red.value; if (cg_warp_group.any(cg_reduce_op.isAmin(val, prev_red_val))) { ValT reduced_val = cg::reduce(cg_warp_group, val, cg_reduce_op); bool pred = (reduced_val == val); auto subTile = cg::binary_partition(cg_warp_group, pred); if (pred) { - if (subTile.thread_rank() == 0) { - shmem_ptr = val_to_red; - } + if (subTile.thread_rank() == 0) { shmem_ptr = val_to_red; } } } } @@ -300,7 +304,7 @@ class PredicatedTileIteratorReducedVec { /// Internal state counter int state_[3]; - //mutable int shared_tile_id; + // mutable int shared_tile_id; /// Scatter indices int const* indices_; @@ -322,7 +326,6 @@ class PredicatedTileIteratorReducedVec { // Methods // public: - // // Methods // @@ -336,10 +339,11 @@ class PredicatedTileIteratorReducedVec { const bool& do_gmem_reduction, TensorCoord threadblock_offset = TensorCoord(), int const* indices = nullptr) - : params_(params), indices_(indices), shared_storage_(shared_storage), + : params_(params), + indices_(indices), + shared_storage_(shared_storage), do_gmem_reduction_(do_gmem_reduction) { - TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; extent_row_ = extent.row(); @@ -350,19 +354,20 @@ class PredicatedTileIteratorReducedVec { TensorCoord block_offset = ThreadMap::initial_offset(0) + threadblock_offset; block_start_row_first_tile_ = block_offset.row(); - + // Initialize predicates CUTLASS_PRAGMA_UNROLL for (int c = 0; c < ThreadMap::Iterations::kColumn * kElementsPerAccess; ++c) { - int columnPerAccess = (c / kElementsPerAccess); - int columnWithinPerAccess = c % kElementsPerAccess; - mask_.predicates[c] = ((thread_offset.column() + ThreadMap::Delta::kColumn * columnPerAccess + columnWithinPerAccess) < extent.column()); + int columnPerAccess = (c / kElementsPerAccess); + int columnWithinPerAccess = c % kElementsPerAccess; + mask_.predicates[c] = ((thread_offset.column() + ThreadMap::Delta::kColumn * columnPerAccess + + columnWithinPerAccess) < extent.column()); } if (threadblock_offset.column() == 0) { - Element* shared_elem_arr = shared_storage_.data(); + Element* shared_elem_arr = shared_storage_.data(); EpilogueOpParams const& user_params = params_.user_param; - constexpr auto maxVal = std::numeric_limits::max(); + constexpr auto maxVal = std::numeric_limits::max(); for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { user_params.red_op_.init(&shared_elem_arr[row], maxVal); @@ -387,13 +392,14 @@ class PredicatedTileIteratorReducedVec { state_[0] = state_[1] = state_[2] = 0; } - /// Destructor + /// Destructor CUTLASS_DEVICE - ~PredicatedTileIteratorReducedVec() { + ~PredicatedTileIteratorReducedVec() + { if (do_gmem_reduction_) { EpilogueOpParams const& user_params = params_.user_param; - auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); Element* shared_elem_arr = shared_storage_.data(); // If this is not optimal grid size perform mutex based gmem reduce. @@ -404,20 +410,18 @@ class PredicatedTileIteratorReducedVec { // acquire mutex lock. unsigned int ns = 8; while (atomicCAS(user_params.mutexes_ + mutex_id, 0, 1) == 1) { - __nanosleep(ns); - if (ns < 256) { - ns *= 2; - } + __nanosleep(ns); + if (ns < 256) { ns *= 2; } } } - __syncthreads(); + __syncthreads(); - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - if (block_start_row_first_tile_ + row < extent_row_) { - user_params.red_op_(0, &gmem_ptr[row], shared_elem_arr[row]); + for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { + if (block_start_row_first_tile_ + row < extent_row_) { + user_params.red_op_(0, &gmem_ptr[row], shared_elem_arr[row]); + } } - } __threadfence(); __syncthreads(); @@ -449,17 +453,17 @@ class PredicatedTileIteratorReducedVec { { AccessTypeValT* frag_ptr = reinterpret_cast(&frag); - cg::thread_block cta = cg::this_thread_block(); + cg::thread_block cta = cg::this_thread_block(); // tile_width 16 is required if kElementPerAccess > 1 - constexpr int tile_width = (32 / ThreadMap::Delta::kColumn) ? 32 : 16; + constexpr int tile_width = (32 / ThreadMap::Delta::kColumn) ? 32 : 16; cg::thread_block_tile tile32 = cg::tiled_partition(cta); - EpilogueOpParams const& user_params = params_.user_param; + EpilogueOpParams const& user_params = params_.user_param; using cg_reduce_t = decltype(user_params.cg_reduce_op); using tile32_t = decltype(tile32); Element* shared_elem_arr = shared_storage_.data(); - constexpr auto maxVal = std::numeric_limits::max(); + constexpr auto maxVal = std::numeric_limits::max(); CUTLASS_PRAGMA_UNROLL for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { @@ -479,27 +483,29 @@ class PredicatedTileIteratorReducedVec { Element red_val; user_params.red_op_.init(&red_val, maxVal); if (row_guard) { - - const int iter_row = ((row_offset + thread_start_row_) % total_rows); + const int iter_row = ((row_offset + thread_start_row_) % total_rows); const auto prev_red_val = user_params.red_op_.get_value(shared_elem_arr[iter_row]); CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn * kElementsPerAccess; ++column) { - int columnPerAccess = column / kElementsPerAccess; + for (int column = 0; column < ThreadMap::Iterations::kColumn * kElementsPerAccess; + ++column) { + int columnPerAccess = column / kElementsPerAccess; int columnWithPerAccess = column % kElementsPerAccess; - bool guard = mask_.predicates[column]; + bool guard = mask_.predicates[column]; if (guard) { - const OutIdxT key_id = thread_start_column_ + ThreadMap::Delta::kColumn * columnPerAccess + columnWithPerAccess; + const OutIdxT key_id = thread_start_column_ + + ThreadMap::Delta::kColumn * columnPerAccess + + columnWithPerAccess; const int frag_col_idx = frag_idx + column; Element this_val; user_params.red_op_.init(&this_val, (*frag_ptr)[frag_col_idx]); - user_params.red_op_.init_key(this_val, key_id ); - user_params.red_op_(key_id , &red_val, this_val); + user_params.red_op_.init_key(this_val, key_id); + user_params.red_op_(key_id, &red_val, this_val); } } select_reduce red_obj( - red_val, prev_red_val, user_params.cg_reduce_op, tile32, shared_elem_arr[iter_row]); + red_val, prev_red_val, user_params.cg_reduce_op, tile32, shared_elem_arr[iter_row]); } } } @@ -534,7 +540,6 @@ class PredicatedTileIteratorReducedVec { PredicatedTileIteratorReducedVec& operator++() { ++state_[0]; - //shared_tile_id++; // tile iteration. if (!ScatterD) { byte_pointer_ += params_.advance_row; } diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index df8529cb72..94c6a3a3bd 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -61,22 +61,17 @@ struct MinAndDistanceReduceOpImpl { } DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } - DI void init(KVP* out, DataT maxVal) const - { - out->value = maxVal; - } + DI void init(KVP* out, DataT maxVal) const { out->value = maxVal; } DI void init_key(DataT& out, LabelT idx) const { return; } DI void init_key(KVP& out, LabelT idx) const { out.key = idx; } DI DataT get_value(KVP& out) const { - return out.value;; - } - DI DataT get_value(DataT& out) const - { - return out; + return out.value; + ; } + DI DataT get_value(DataT& out) const { return out; } }; template @@ -268,10 +263,9 @@ struct kvp_cg_min_reduce_op { // functor signature. __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } -__host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } - -__host__ __device__ bool isAmin(AccType a, AccType b) const { return a < b ? true : false; } + __host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } + __host__ __device__ bool isAmin(AccType a, AccType b) const { return a < b ? true : false; } }; template distance_op{sqrt}; @@ -370,7 +363,6 @@ void fusedL2NNImpl(OutT* min, min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); RAFT_CUDA_TRY(cudaGetLastError()); } - } } // namespace detail diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h index 3c891229f8..cd748b9e6b 100644 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h @@ -290,7 +290,8 @@ class PredicatedTileIteratorNormVec { (void*)&memory_pointer[0], guard); } else { - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] = frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn]; + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn]; } } diff --git a/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh b/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh index 679aab72a9..bd30d5d643 100644 --- a/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh +++ b/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh @@ -82,23 +82,18 @@ struct FixConnectivitiesRedOp { } DI void init(value_t* out, value_t maxVal) const { *out = maxVal; } - DI void init(KVP* out, value_t maxVal) const - { - out->value = maxVal; - } + DI void init(KVP* out, value_t maxVal) const { out->value = maxVal; } DI void init_key(value_t& out, value_idx idx) const { return; } DI void init_key(KVP& out, value_idx idx) const { out.key = idx; } DI value_t get_value(KVP& out) const { - return out.value;; + return out.value; + ; } - DI value_t get_value(value_t& out) const - { - return out; - } + DI value_t get_value(value_t& out) const { return out; } }; /** diff --git a/cpp/include/raft/util/cutlass_utils.cuh b/cpp/include/raft/util/cutlass_utils.cuh index 3456c0c3e5..b60ca644a4 100644 --- a/cpp/include/raft/util/cutlass_utils.cuh +++ b/cpp/include/raft/util/cutlass_utils.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index c8d4f91ec0..9b849bc261 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -245,14 +245,8 @@ if(BUILD_TESTS) ) ConfigureTest( - NAME - SPARSE_DIST_TEST - PATH - test/sparse/dist_coo_spmv.cu - test/sparse/distance.cu - test/sparse/gram.cu - OPTIONAL - LIB + NAME SPARSE_DIST_TEST PATH test/sparse/dist_coo_spmv.cu test/sparse/distance.cu + test/sparse/gram.cu OPTIONAL LIB ) ConfigureTest( From c116d234ca1f206d2ff0393f8631d6507b9e1ae4 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 28 Apr 2023 04:46:43 -0700 Subject: [PATCH 28/48] fix copyright and formatting issues in couple cutlass source files --- .../custom_epilogue_with_broadcast.h | 50 ++++++++++++------- .../fused_distance_nn/persistent_gemm.h | 18 ++++++- 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h index d90d75a4b4..758f330c16 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h @@ -1,5 +1,21 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -47,26 +63,26 @@ #include #endif -#include "cutlass/aligned_buffer.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/functional.h" -#include "cutlass/layout/tensor.h" -#include "cutlass/layout/vector.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_coord.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include -#include "cutlass/gemm/gemm.h" +#include -#include "cutlass/transform/pitch_linear_thread_map.h" -#include "cutlass/transform/threadblock/regular_tile_iterator.h" +#include +#include -#include "cutlass/epilogue/threadblock/epilogue_base.h" -#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include +#include -#include "cutlass/numeric_types.h" +#include ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h index cca25c0cdd..1bf38bed92 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h @@ -1,5 +1,21 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without From 527c89d7c1cf36cb3b4d2b7574a3aed9f92a572e Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 28 Apr 2023 05:49:00 -0700 Subject: [PATCH 29/48] restrict cutlass kernel to sm 80+ using _cuda_arch_ --- .../raft/distance/detail/fused_distance_nn/persistent_gemm.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h index 1bf38bed92..107d6a4026 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h @@ -369,6 +369,7 @@ struct FusedDistanceNNPersistent { CUTLASS_DEVICE void operator()(Params const& params, SharedStorage& shared_storage) { +#if __CUDA_ARCH__ >= 800 // // These types shadow the type-level definitions and support the ability to implement // a 'transposed' GEMM that computes the transposed problems. @@ -523,6 +524,7 @@ struct FusedDistanceNNPersistent { problem_size.mn(), threadblock_offset); } +#endif } }; From 90c2c3928cbca24a14a5c7243d2ca796832eb048 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 28 Apr 2023 07:29:18 -0700 Subject: [PATCH 30/48] add ignore -Wtautological-compare to not report warnings as error in cutlass header uint128.h --- .../raft/distance/detail/fused_distance_nn/cutlass_base.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh index 2b742f9ee6..29afa536ae 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh @@ -18,6 +18,7 @@ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wtautological-compare" // We define CUTLASS_NAMESPACE in case // RAFT cmake is not used From f5a493b41862bb5d9fbeede27094b400464924d2 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 28 Apr 2023 08:21:56 -0700 Subject: [PATCH 31/48] add ignore warning pragma -Wtautological-compare to pairwise_distance cutlass base, make use of RAFT_CUTLASS_TRY instead of its local CUTLASS_CHECK macro --- .../detail/pairwise_distance_cutlass_base.cuh | 23 ++++++------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh index efcd5d9389..ccb3bd46bf 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -18,6 +18,7 @@ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wtautological-compare" // We define CUTLASS_NAMESPACE in case // RAFT cmake is not used @@ -38,20 +39,11 @@ #include #include +#include #include "./pairwise_distance_epilogue_elementwise.h" #include "./pairwise_distance_gemm.h" -#define CUTLASS_CHECK(status) \ - { \ - cutlass::Status error = status; \ - if (error != cutlass::Status::kSuccess) { \ - std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ - << std::endl; \ - exit(EXIT_FAILURE); \ - } \ - } - namespace raft { namespace distance { namespace detail { @@ -164,14 +156,13 @@ std::enable_if_t::value> cutlassDistanceKernel(const Da // Instantiate CUTLASS kernel depending on templates cutlassDist cutlassDist_op; // Check the problem size is supported or not - cutlass::Status status = cutlassDist_op.can_implement(arguments); - CUTLASS_CHECK(status); + RAFT_CUTLASS_TRY(cutlassDist_op.can_implement(arguments)); + // Initialize CUTLASS kernel with arguments and workspace pointer - status = cutlassDist_op.initialize(arguments, workspace.data(), stream); - CUTLASS_CHECK(status); + RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream)); + // Launch initialized CUTLASS kernel - status = cutlassDist_op(); - CUTLASS_CHECK(status); + RAFT_CUTLASS_TRY(cutlassDist_op()); } }; // namespace detail From ab219fcdd369527962ac3048398ea20eccd213d2 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 4 May 2023 08:26:37 -0700 Subject: [PATCH 32/48] fix the failure in cluster_test due to incorrect row_id passed to reduce_op functor, do not launch redundant threadblocks for small input size cases, code cleanup, comments --- .../detail/fused_distance_nn/cutlass_base.cuh | 17 ++++--- .../epilogue_elementwise.cuh | 6 +-- .../predicated_tile_iterator_reduced_vec.h | 50 +++++++++++-------- .../neighbors/detail/connect_components.cuh | 12 ++--- 4 files changed, 49 insertions(+), 36 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh index 29afa536ae..a7d9f49335 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh @@ -37,10 +37,10 @@ #include #include -#include -#include -#include -#include +#include // FusedDistanceNNEpilogueElementwise +#include // FusedDistanceNNGemm +#include // getMultiProcessorCount +#include // RAFT_CUTLASS_TRY namespace raft { namespace distance { @@ -115,9 +115,14 @@ void cutlassFusedDistanceNN(const DataT* x, int num_blocks_per_sm = fusedDistanceNN::maximum_active_blocks(); int num_sms = raft::getMultiProcessorCount(); - int num_blocks = num_blocks_per_sm * num_sms; + int full_wave = num_blocks_per_sm * num_sms; constexpr int mmaShapeM = fusedDistanceNNKernel::Mma::Shape::kM; - auto thread_blocks = std::max(num_blocks, int((problem_size.m() - 1 + mmaShapeM) / mmaShapeM)); + constexpr int mmaShapeN = fusedDistanceNNKernel::Mma::Shape::kN; + int columnTiles = (problem_size.n() - 1 + mmaShapeN) / mmaShapeN; + int rowTiles = (problem_size.m() - 1 + mmaShapeM) / mmaShapeM; + + int thread_blocks = + rowTiles < full_wave ? (columnTiles < full_wave ? columnTiles : full_wave) : rowTiles; typename fusedDistanceNN::Arguments arguments{ problem_size, diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh index 4058406bc1..7f914bb30c 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh @@ -73,8 +73,7 @@ class FusedDistanceNNEpilogueElementwise { using FragmentC = Array; using FragmentZ = Array; using OutValT = typename CGReduceOp::AccTypeT; - // using FragmentT = Array; - using FragmentT = Array; + using FragmentT = Array; using FragmentOutput = FragmentZ; @@ -164,8 +163,7 @@ class FusedDistanceNNEpilogueElementwise { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kElementsPerAccess; ++i) { ElementCompute res_Z = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); - // red_op.init(&frag_T[i], res_Z); - frag_T[i] = res_Z; + frag_T[i] = res_Z; } } diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index c3cb3a458c..014fafd078 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -208,7 +208,7 @@ class PredicatedTileIteratorReducedVec { typename ValT, typename OutT> struct select_reduce { - /// Performs reduction and stores a reduced output to memory + /// Performs warp level reduction and stores a reduced output to memory CUTLASS_DEVICE select_reduce(OutT value, ValT prev_red_val, @@ -227,7 +227,7 @@ class PredicatedTileIteratorReducedVec { struct select_reduce> { using ValT = float; using Ty = raft::KeyValuePair; - + /// Performs warp level reduction of key value pair and stores a reduced output to memory CUTLASS_DEVICE select_reduce(Ty val_to_red, float prev_red_val, @@ -252,7 +252,7 @@ class PredicatedTileIteratorReducedVec { struct select_reduce> { using ValT = double; using Ty = raft::KeyValuePair; - + /// Performs warp level reduction of key value pair and stores a reduced output to memory CUTLASS_DEVICE select_reduce(Ty val_to_red, double prev_red_val, @@ -399,9 +399,6 @@ class PredicatedTileIteratorReducedVec { if (do_gmem_reduction_) { EpilogueOpParams const& user_params = params_.user_param; - auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - Element* shared_elem_arr = shared_storage_.data(); - // If this is not optimal grid size perform mutex based gmem reduce. if ((gridDim.x != ((extent_row_ - 1 + Shape::kRow) / Shape::kRow))) { const auto mutex_id = (block_start_row_first_tile_ / total_rows); @@ -417,25 +414,34 @@ class PredicatedTileIteratorReducedVec { __syncthreads(); - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - if (block_start_row_first_tile_ + row < extent_row_) { - user_params.red_op_(0, &gmem_ptr[row], shared_elem_arr[row]); - } - } + store_output_shared_to_global(); __threadfence(); __syncthreads(); if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { // release mutex lock. - atomicCAS(user_params.mutexes_ + mutex_id, 1, 0); + // atomicCAS(user_params.mutexes_ + mutex_id, 1, 0); + atomicExch(user_params.mutexes_ + mutex_id, 0); } } else { __syncthreads(); - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - if (block_start_row_first_tile_ + row < extent_row_) { - gmem_ptr[row] = shared_elem_arr[row]; - } - } + store_output_shared_to_global(); + } + } + } + + /// store the final shared mem output to global mem + CUTLASS_DEVICE + void store_output_shared_to_global() + { + EpilogueOpParams const& user_params = params_.user_param; + Element* shared_elem_arr = shared_storage_.data(); + auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + + for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { + OutIdxT g_row_id = block_start_row_first_tile_ + row; + if (g_row_id < extent_row_) { + user_params.red_op_(g_row_id, gmem_ptr + row, shared_elem_arr[row]); } } } @@ -477,13 +483,15 @@ class PredicatedTileIteratorReducedVec { int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + const OutIdxT row_id = row_offset + thread_start_row_; + bool row_guard = (row_id < extent_row_); const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn * kElementsPerAccess; Element red_val; user_params.red_op_.init(&red_val, maxVal); + if (row_guard) { - const int iter_row = ((row_offset + thread_start_row_) % total_rows); + const int iter_row = (row_id % total_rows); const auto prev_red_val = user_params.red_op_.get_value(shared_elem_arr[iter_row]); CUTLASS_PRAGMA_UNROLL @@ -501,9 +509,11 @@ class PredicatedTileIteratorReducedVec { Element this_val; user_params.red_op_.init(&this_val, (*frag_ptr)[frag_col_idx]); user_params.red_op_.init_key(this_val, key_id); - user_params.red_op_(key_id, &red_val, this_val); + user_params.red_op_(row_id, &red_val, this_val); } } + // select_reduce doesn't need to use `red_op_` as at the warp level we use cg_reduce_op, + // this satisfies the requirement of mst/single linkage of checking colors buffer. select_reduce red_obj( red_val, prev_red_val, user_params.cg_reduce_op, tile32, shared_elem_arr[iter_row]); } diff --git a/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh b/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh index bd30d5d643..0215c57ae1 100644 --- a/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh +++ b/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh @@ -82,16 +82,16 @@ struct FixConnectivitiesRedOp { } DI void init(value_t* out, value_t maxVal) const { *out = maxVal; } - DI void init(KVP* out, value_t maxVal) const { out->value = maxVal; } + DI void init(KVP* out, value_t maxVal) const + { + out->key = -1; + out->value = maxVal; + } DI void init_key(value_t& out, value_idx idx) const { return; } DI void init_key(KVP& out, value_idx idx) const { out.key = idx; } - DI value_t get_value(KVP& out) const - { - return out.value; - ; - } + DI value_t get_value(KVP& out) const { return out.value; } DI value_t get_value(value_t& out) const { return out; } }; From 5ab9f59079191e8bfb6697fa77d92b8fdb28309c Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 4 May 2023 08:56:14 -0700 Subject: [PATCH 33/48] remove redundant code in custom_epilogue_with_broadcast.h, and add comments about this source file --- .../custom_epilogue_with_broadcast.h | 115 +----------------- 1 file changed, 5 insertions(+), 110 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h index 758f330c16..ab028f95c5 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h @@ -51,6 +51,11 @@ The epilogue rearranges the result of a matrix product through shared memory to match canonical tensor layouts in global memory. Epilogues support conversion and reduction operations. +This file contains a customized version of EpilogueWithBroadcast from CUTLASS 2.9.1 +(https://github.com/NVIDIA/cutlass/blob/v2.9.1/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h) + +Changes: +- customized the compute_source_needed_() and apply_output_operator_() to suit the needs of per row reduction */ #pragma once @@ -512,100 +517,6 @@ class EpilogueWithBroadcastCustom : public EpilogueBase>::push(iter, - accum_fragment_iterator, - this->warp_tile_iterator_); - - __syncthreads(); - - // - // Load fragments from shared memory - // - - CUTLASS_PRAGMA_UNROLL - for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { - - - typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; - - shared_load_iterator_.load(aligned_accum_fragment[0]); - - if (p < Base::kFragmentsPerIteration - 1) { - shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); - } - else if (kPartitionsK > 1) { - - plus add_fragments; - - CUTLASS_PRAGMA_UNROLL - for ( int i = 1; i < kPartitionsK; ++i) { - shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); - shared_load_iterator_.load(aligned_accum_fragment[i]); - aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); - } - - shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); - } - - // - // Apply output operation - // - - typename OutputTileIterator::Fragment frag_Z; - typename TensorTileIterator::Fragment frag_T; - - apply_output_operator_source_not_needed_( - frag_Z, - frag_T, - output_op, - aligned_accum_fragment[0], - broadcast_fragment); - - // - // Conditionally store fragments - // - - if (OutputOp::kStoreZ) { - destination_iterator.store(frag_Z); - ++destination_iterator; - } - - if (OutputOp::kStoreT) { - tensor_iterator.store(frag_T); - ++tensor_iterator; - } - } - - if (Base::kFragmentsPerIteration > 1) { - shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); - } - } -#endif } template @@ -681,23 +592,7 @@ class EpilogueWithBroadcastCustom : public EpilogueBase 1 - perform a reduction amongst the k-slices - if (kPartitionsK > 1) - { - plus add_fragments; - const int tile_row_offset = Base::SharedStorage::StorageShape::kRow / PartitionsK; - - CUTLASS_PRAGMA_UNROLL - for ( int i = 1; i < kPartitionsK; ++i) { - shared_load_iterator_.add_tile_offset({tile_row_offset , 0}); - shared_load_iterator_.load(aligned_accum_fragment[i]); - aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); - } - shared_load_iterator_.add_tile_offset({-1 * (kPartitionsK-1) * tile_row_offset, 0}); - } -#endif // // Apply output operation // From 1fce36a63963561d70f801ab0dfee6cd84430443 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 5 May 2023 02:02:06 -0700 Subject: [PATCH 34/48] remove the redundant header inclusion in fusedl2knn tests which was previously needed --- cpp/test/neighbors/fused_l2_knn.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/test/neighbors/fused_l2_knn.cu b/cpp/test/neighbors/fused_l2_knn.cu index be0c26740a..9fbccf681d 100644 --- a/cpp/test/neighbors/fused_l2_knn.cu +++ b/cpp/test/neighbors/fused_l2_knn.cu @@ -18,7 +18,6 @@ #include "./knn_utils.cuh" #include -#include #include #include #include From b072d8097d3530c733eac6aff6f5067639e48fa5 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 5 May 2023 02:40:04 -0700 Subject: [PATCH 35/48] remove commented code and fix formating --- .../fused_distance_nn/custom_epilogue_with_broadcast.h | 3 ++- cpp/include/raft/distance/detail/fused_distance_nn/gemm.h | 2 -- .../predicated_tile_iterator_normvec_smem.h | 6 +----- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h index ab028f95c5..1171d25727 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h @@ -55,7 +55,8 @@ This file contains a customized version of EpilogueWithBroadcast from CUTLASS 2. (https://github.com/NVIDIA/cutlass/blob/v2.9.1/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h) Changes: -- customized the compute_source_needed_() and apply_output_operator_() to suit the needs of per row reduction +- customized the compute_source_needed_() and apply_output_operator_() to suit the needs of per row +reduction */ #pragma once diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h index feb078fcc0..48ac0b232c 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h @@ -47,8 +47,6 @@ template < typename ElementC_, /// Element type for internal accumulation typename ElementAccumulator, - /// Element type for final output - // typename ElementOutT, /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' typename EpilogueOutputOp, /// Number of stages used in the pipelined mainloop diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h index 86d5e688a8..4a22a4a7fc 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h @@ -81,9 +81,6 @@ class PredicatedTileIteratorNormVecSmem { static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; static int const kThreads = ThreadMap::kThreads; static int const kIterations = ThreadMap::Count::kTile; - // static int const total_rows = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * - // ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * - // kIterations; static int const total_rows = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * @@ -190,7 +187,7 @@ class PredicatedTileIteratorNormVecSmem { /// Byte-level pointer uint8_t* byte_pointer_; - // uint8_t* first_tile_byte_pointer_; + /// Array of boolean values to contain steady-state predicates Mask mask_; @@ -241,7 +238,6 @@ class PredicatedTileIteratorNormVecSmem { Element* pointer, TensorCoord extent, int thread_idx, - // const bool init_shmem, TensorCoord& threadblock_offset, int const* indices = nullptr) : params_(params), indices_(indices), shared_storage_(shared_storage) From a7c73037e0adde1996ee6926c21195dd7b44a73b Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 5 May 2023 03:06:03 -0700 Subject: [PATCH 36/48] add larger input test cases to test fusedL2nn all code paths --- .../fused_distance_nn/predicated_tile_iterator_reduced_vec.h | 1 - cpp/test/distance/fused_l2_nn.cu | 5 ++++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index 014fafd078..bb82fa8eb8 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -420,7 +420,6 @@ class PredicatedTileIteratorReducedVec { __syncthreads(); if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { // release mutex lock. - // atomicCAS(user_params.mutexes_ + mutex_id, 1, 0); atomicExch(user_params.mutexes_ + mutex_id, 0); } } else { diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index c4ccd55f69..60e977f087 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -275,6 +275,8 @@ const std::vector> inputsf = { {0.001f, 128, 128, 65, 1234ULL}, {0.001f, 64, 128, 129, 1234ULL}, {0.006f, 1805, 134, 2, 1234ULL}, + {0.006f, 8192, 1024, 64, 1234ULL}, + {0.006f, 8192, 1025, 64, 1234ULL}, // Repeat with smaller values of k {0.006f, 32, 32, 1, 1234ULL}, @@ -304,6 +306,7 @@ const std::vector> inputsf = { {0.001f, 128, 128, 23, 1234ULL}, {0.00001, 64, 128, 24, 1234ULL}, {0.001f, 1805, 134, 25, 1234ULL}, + {0.006f, 8192, 1024, 25, 1234ULL}, }; typedef FusedL2NNTest FusedL2NNTestF_Sq; TEST_P(FusedL2NNTestF_Sq, Result) @@ -338,7 +341,7 @@ const std::vector> inputsd = { {0.00001, 128, 32, 33, 1234ULL}, {0.00001, 128, 64, 33, 1234ULL}, {0.00001, 128, 128, 65, 1234ULL}, {0.00001, 64, 128, 129, 1234ULL}, - {0.00001, 1805, 134, 2, 1234ULL}, + {0.00001, 1805, 134, 2, 1234ULL}, {0.00001, 8192, 1024, 25, 1234ULL}, }; typedef FusedL2NNTest FusedL2NNTestD_Sq; TEST_P(FusedL2NNTestD_Sq, Result) From d741c4ad3858f516f08f546264f4ef89282c0092 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Mon, 8 May 2023 11:20:12 -0700 Subject: [PATCH 37/48] fix launch config for small input sizes, fix atomicCAS optimal path selection, add comments on register spills tile shape, add test case for veclen=2 --- .../detail/fused_distance_nn/cutlass_base.cuh | 4 +- .../distance/detail/fused_distance_nn/gemm.h | 20 +++++++--- .../predicated_tile_iterator_reduced_vec.h | 37 ++++++++----------- cpp/test/distance/fused_l2_nn.cu | 1 + 4 files changed, 32 insertions(+), 30 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh index a7d9f49335..a1cf1a9b17 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh @@ -120,9 +120,9 @@ void cutlassFusedDistanceNN(const DataT* x, constexpr int mmaShapeN = fusedDistanceNNKernel::Mma::Shape::kN; int columnTiles = (problem_size.n() - 1 + mmaShapeN) / mmaShapeN; int rowTiles = (problem_size.m() - 1 + mmaShapeM) / mmaShapeM; - + int totalTiles = columnTiles * rowTiles; int thread_blocks = - rowTiles < full_wave ? (columnTiles < full_wave ? columnTiles : full_wave) : rowTiles; + rowTiles < full_wave ? (totalTiles < full_wave ? totalTiles : full_wave) : rowTiles; typename fusedDistanceNN::Arguments arguments{ problem_size, diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h index 48ac0b232c..e64db73a07 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h @@ -57,19 +57,27 @@ struct FusedDistanceNNGemm { // This struct is specialized for fp32/3xTF32 /// Threadblock-level tile size (concept: GemmShape) - // <- threadblock tile M = 32, N = 64, K = 16 + // <- threadblock tile M = 32, N = 256, K = 16 + // this is more performant but note than for veclen = 1 + // this shape has register spills using ThreadblockShape = - cutlass::gemm::GemmShape<32, 256, 16>; // this is more performant for grouped GEMM - // using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; // this shape has high - // occupancy but less perf + cutlass::gemm::GemmShape<32, 256, 16>; + + + // <- threadblock tile M = 32, N = 128, K = 16 + // this shape has high occupancy but less perf + // this is less performant but this shape has *no* register spills + // for any veclens(1, 2, 4) + //using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; /// Warp-level tile size (concept: GemmShape) // This code section describes tile size a warp will compute // <- warp tile M = 64, N = 64, K = 16 using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; // this is more performant for grouped GEMM - // using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; // this shape has high occupancy but - // less perf + + // this shape has high occupancy but less perf used for 32x128x16 + //using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; /// Warp-level tile size (concept: GemmShape) // This code section describes the size of MMA op diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index bb82fa8eb8..93cb83e1fe 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -89,7 +89,6 @@ class PredicatedTileIteratorReducedVec { using OutValT = typename EpilogueOpParams::CGReduceT::AccTypeT; static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - // static int const kElementsPerAccess = 1; static int const kThreads = ThreadMap::kThreads; static int const kIterations = ThreadMap::Count::kTile; @@ -398,9 +397,11 @@ class PredicatedTileIteratorReducedVec { { if (do_gmem_reduction_) { EpilogueOpParams const& user_params = params_.user_param; + auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + Element* shared_elem_arr = shared_storage_.data(); // If this is not optimal grid size perform mutex based gmem reduce. - if ((gridDim.x != ((extent_row_ - 1 + Shape::kRow) / Shape::kRow))) { + if ((gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows))) { const auto mutex_id = (block_start_row_first_tile_ / total_rows); // single lock per block for multiple rows if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { @@ -414,8 +415,11 @@ class PredicatedTileIteratorReducedVec { __syncthreads(); - store_output_shared_to_global(); - + for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { + if (block_start_row_first_tile_ + row < extent_row_) { + user_params.red_op_(block_start_row_first_tile_ + row, &gmem_ptr[row], shared_elem_arr[row]); + } + } __threadfence(); __syncthreads(); if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { @@ -424,23 +428,12 @@ class PredicatedTileIteratorReducedVec { } } else { __syncthreads(); - store_output_shared_to_global(); - } - } - } - - /// store the final shared mem output to global mem - CUTLASS_DEVICE - void store_output_shared_to_global() - { - EpilogueOpParams const& user_params = params_.user_param; - Element* shared_elem_arr = shared_storage_.data(); - auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - OutIdxT g_row_id = block_start_row_first_tile_ + row; - if (g_row_id < extent_row_) { - user_params.red_op_(g_row_id, gmem_ptr + row, shared_elem_arr[row]); + for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { + if (block_start_row_first_tile_ + row < extent_row_) { + //user_params.red_op_(block_start_row_first_tile_ + row, &gmem_ptr[row], shared_elem_arr[row]); + gmem_ptr[row] = shared_elem_arr[row]; + } + } } } } @@ -482,7 +475,7 @@ class PredicatedTileIteratorReducedVec { int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; - const OutIdxT row_id = row_offset + thread_start_row_; + const OutIdxT row_id = row_offset + thread_start_row_; bool row_guard = (row_id < extent_row_); const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn * kElementsPerAccess; diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index 60e977f087..e31d59cfd4 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -307,6 +307,7 @@ const std::vector> inputsf = { {0.00001, 64, 128, 24, 1234ULL}, {0.001f, 1805, 134, 25, 1234ULL}, {0.006f, 8192, 1024, 25, 1234ULL}, + {0.006f, 8192, 1024, 66, 1234ULL}, }; typedef FusedL2NNTest FusedL2NNTestF_Sq; TEST_P(FusedL2NNTestF_Sq, Result) From 7e4b298f4cc726393552167e59a22eaa83e3b3be Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 10 May 2023 02:15:42 -0700 Subject: [PATCH 38/48] fix formatting issues --- .../raft/distance/detail/fused_distance_nn/gemm.h | 10 ++++------ .../predicated_tile_iterator_reduced_vec.h | 14 +++++++------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h index e64db73a07..22335bac72 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h @@ -60,24 +60,22 @@ struct FusedDistanceNNGemm { // <- threadblock tile M = 32, N = 256, K = 16 // this is more performant but note than for veclen = 1 // this shape has register spills - using ThreadblockShape = - cutlass::gemm::GemmShape<32, 256, 16>; - + using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 16>; // <- threadblock tile M = 32, N = 128, K = 16 // this shape has high occupancy but less perf // this is less performant but this shape has *no* register spills // for any veclens(1, 2, 4) - //using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; + // using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; /// Warp-level tile size (concept: GemmShape) // This code section describes tile size a warp will compute // <- warp tile M = 64, N = 64, K = 16 using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; // this is more performant for grouped GEMM - + // this shape has high occupancy but less perf used for 32x128x16 - //using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; + // using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; /// Warp-level tile size (concept: GemmShape) // This code section describes the size of MMA op diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index 93cb83e1fe..8194eeccd5 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -89,8 +89,8 @@ class PredicatedTileIteratorReducedVec { using OutValT = typename EpilogueOpParams::CGReduceT::AccTypeT; static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - static int const kThreads = ThreadMap::kThreads; - static int const kIterations = ThreadMap::Count::kTile; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); @@ -397,8 +397,8 @@ class PredicatedTileIteratorReducedVec { { if (do_gmem_reduction_) { EpilogueOpParams const& user_params = params_.user_param; - auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - Element* shared_elem_arr = shared_storage_.data(); + auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + Element* shared_elem_arr = shared_storage_.data(); // If this is not optimal grid size perform mutex based gmem reduce. if ((gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows))) { @@ -417,7 +417,8 @@ class PredicatedTileIteratorReducedVec { for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { if (block_start_row_first_tile_ + row < extent_row_) { - user_params.red_op_(block_start_row_first_tile_ + row, &gmem_ptr[row], shared_elem_arr[row]); + user_params.red_op_( + block_start_row_first_tile_ + row, &gmem_ptr[row], shared_elem_arr[row]); } } __threadfence(); @@ -430,7 +431,6 @@ class PredicatedTileIteratorReducedVec { __syncthreads(); for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { if (block_start_row_first_tile_ + row < extent_row_) { - //user_params.red_op_(block_start_row_first_tile_ + row, &gmem_ptr[row], shared_elem_arr[row]); gmem_ptr[row] = shared_elem_arr[row]; } } @@ -475,7 +475,7 @@ class PredicatedTileIteratorReducedVec { int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; - const OutIdxT row_id = row_offset + thread_start_row_; + const OutIdxT row_id = row_offset + thread_start_row_; bool row_guard = (row_id < extent_row_); const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn * kElementsPerAccess; From 7f1d30dc31afcbcbb0e350601ab40f550f7d8617 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 10 May 2023 02:35:12 -0700 Subject: [PATCH 39/48] move raft copyright below cutlass's and fix start year to be 2023 --- .../custom_epilogue_with_broadcast.h | 33 ++++++++++--------- .../detail/fused_distance_nn/epilogue.cuh | 32 +++++++++++++++++- .../epilogue_elementwise.cuh | 32 +++++++++++++++++- .../distance/detail/fused_distance_nn/gemm.h | 33 +++++++++++++++++-- .../fused_distance_nn/persistent_gemm.h | 31 +++++++++-------- .../predicated_tile_iterator_normvec_smem.h | 32 +++++++++++++++++- .../predicated_tile_iterator_reduced_vec.h | 32 +++++++++++++++++- 7 files changed, 187 insertions(+), 38 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h index 1171d25727..585cd14cd5 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h @@ -1,19 +1,3 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - /*************************************************************************************************** * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause @@ -44,6 +28,23 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + /*! \file \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh index 7feaea1f02..8a0bea3469 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh @@ -1,5 +1,35 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh index 7f914bb30c..a21f3d60e0 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh +++ b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh @@ -1,5 +1,35 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h index 22335bac72..84bdd9f087 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h @@ -1,5 +1,35 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,7 +52,6 @@ #include #include -// #include #include #include diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h index 107d6a4026..42cb9278b4 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h @@ -1,19 +1,3 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - /*************************************************************************************************** * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause @@ -44,6 +28,21 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ /*! \file \brief Problem visitor for grouped GEMMs diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h index 4a22a4a7fc..b223043b93 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h @@ -1,5 +1,35 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index 8194eeccd5..5ce7edf618 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -1,5 +1,35 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From 514bd1e327ee5ed9bd8a6c3bb1ad2b4410947011 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 11 May 2023 10:47:46 -0700 Subject: [PATCH 40/48] add specialization for veclen=1 with 32x128x16 having no reg spills, fix cutlass_utils.h comments --- .../distance/detail/fused_distance_nn/gemm.h | 123 +++++++++++++++++- cpp/include/raft/util/cutlass_utils.cuh | 11 +- 2 files changed, 126 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h index 84bdd9f087..bdb18c4af2 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h @@ -62,7 +62,12 @@ namespace gemm { namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// - +/* +* This configuration is used for float inputs with veclen(kAlignmentA/B) = 2 or 4, +* ideal threadblock tile shape is 32x256x16 for such cases as there is no +* registers spills for it. +* +*/ template < /// Element type for A matrix operand typename ElementA_, @@ -87,7 +92,7 @@ struct FusedDistanceNNGemm { /// Threadblock-level tile size (concept: GemmShape) // <- threadblock tile M = 32, N = 256, K = 16 - // this is more performant but note than for veclen = 1 + // this is more performant but note that for veclen = 1 // this shape has register spills using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 16>; @@ -180,6 +185,120 @@ struct FusedDistanceNNGemm { GroupScheduleMode::kDeviceOnly>; }; +/* +* This configuration is used for float inputs with veclen(kAlignmentA/B) = 1, +* ideal threadblock tile shape is 32x128x16 for such cases as there is no +* registers spills for it. +* +*/ +template < + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Number of stages used in the pipelined mainloop + int Stages, + /// data layout row/column major of inputs + bool isRowMajor> +struct FusedDistanceNNGemm { + // This struct is specialized for fp32/3xTF32 + using ElementA_ = float; + using ElementB_ = float; + + /// Threadblock-level tile size (concept: GemmShape) + // <- threadblock tile M = 32, N = 128, K = 16 + // this shape has high occupancy and no register spills for veclen = 1. + using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; + + /// Warp-level tile size (concept: GemmShape) + // This code section describes tile size a warp will compute + // <- warp tile M = 32, N = 32, K = 16 + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; + + /// Warp-level tile size (concept: GemmShape) + // This code section describes the size of MMA op + // <- MMA Op tile M = 16, N = 8, K = 4 + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastF32; + // using Operator = cutlass::arch::OpMultiplyAdd; // this runs only 1xTF32 for float inputs + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using OperatorClass = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using ArchTag = cutlass::arch::Sm80; + + // This code section describes how threadblocks are scheduled on GPU + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + /// data layout for final output matrix. + // we keep this same layout even for column major inputs + using LayoutOutput = cutlass::layout::RowMajor; + + typedef typename std::conditional::type NormXLayout; + + typedef typename std:: + conditional::type LayoutA_; + + typedef typename std:: + conditional::type LayoutB_; + + using GemmBase = typename DefaultGemmUniversal::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::FusedDistanceNNEpilogue< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementAccumulator, + typename EpilogueOutputOp::ElementT, + ElementAccumulator, + EpilogueOutputOp, + NormXLayout, + GemmBase::Epilogue::kElementsPerAccess>::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = FusedDistanceNNPersistent; +}; + template < /// Layout type for A matrix operand int kAlignmentA, diff --git a/cpp/include/raft/util/cutlass_utils.cuh b/cpp/include/raft/util/cutlass_utils.cuh index b60ca644a4..da402c9427 100644 --- a/cpp/include/raft/util/cutlass_utils.cuh +++ b/cpp/include/raft/util/cutlass_utils.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ namespace raft { /** - * @brief Exception thrown when a CUDA error is encountered. + * @brief Exception thrown when a CUTLASS error is encountered. */ struct cutlass_error : public raft::exception { explicit cutlass_error(char const* const message) : raft::exception(message) {} @@ -32,11 +32,10 @@ struct cutlass_error : public raft::exception { } // namespace raft /** - * @brief Error checking macro for CUDA runtime API functions. + * @brief Error checking macro for CUTLASS functions. * - * Invokes a CUDA runtime API function call, if the call does not return - * cudaSuccess, invokes cudaGetLastError() to clear the error and throws an - * exception detailing the CUDA error that occurred + * Invokes a CUTLASS function call, if the call does not return cutlass::Status::kSuccess, + * throws an exception detailing the CUTLASS error that occurred. * */ #define RAFT_CUTLASS_TRY(call) \ From cba6bbc44a98bbd0a55da7c0b828c42efb7329e7 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 11 May 2023 11:31:56 -0700 Subject: [PATCH 41/48] move smem init code to their respective tile iterators instead of having 2 copies of the same code --- .../fused_distance_nn/persistent_gemm.h | 36 +++---------------- .../predicated_tile_iterator_normvec_smem.h | 26 ++++++++------ .../predicated_tile_iterator_reduced_vec.h | 17 +++++---- 3 files changed, 31 insertions(+), 48 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h index 42cb9278b4..e48c57be99 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h @@ -379,9 +379,6 @@ struct FusedDistanceNNPersistent { using LayoutB = typename Mma::IteratorB::Layout; using ElementC = typename Epilogue::OutputTileIterator::Element; using LayoutC = typename Epilogue::OutputTileIterator::Layout; - using ElementOut = typename Epilogue::TensorTileIterator::Element; - using LongIndexOut = typename Epilogue::TensorTileIterator::LongIndex; - using OutValTy = typename Epilogue::TensorTileIterator::OutValT; const GemmCoord& problem_size = params.problem_size; const uint32_t problem_chunk = @@ -390,35 +387,12 @@ struct FusedDistanceNNPersistent { const auto grid_shape_ = grid_shape(problem_size); typename LayoutB::Index column = ((blockIdx.x * problem_chunk) % grid_shape_.column()) * Mma::Shape::kN; - { - ElementOut* shared_elem_arr_ = shared_storage.reduced_store.data(); - constexpr auto maxVal_ = std::numeric_limits::max(); - - if (column) { - for (int row = threadIdx.x; row < Mma::Shape::kM; row += blockDim.x) { - params.output_op.red_op_.init(&shared_elem_arr_[row], maxVal_); - } - } - } - { - ElementC* shared_elem_arr = shared_storage.rownorm_store.data(); - if (column) { - typename LayoutB::Index row = - ((blockIdx.x * problem_chunk) / grid_shape_.column()) * Mma::Shape::kM; - - uint8_t* first_tile_byte_pointer_ = - reinterpret_cast(params.ptr_C) + - typename LayoutB::LongIndex(row) * typename LayoutB::LongIndex(sizeof(ElementC)); - auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - - for (int row_local = threadIdx.x; row_local < Mma::Shape::kM; row_local += blockDim.x) { - bool guard = (row + row_local) < problem_size.m(); - cutlass::arch::cp_async( - shared_elem_arr + row_local, gmem_ptr + row_local, guard); - cutlass::arch::cp_async_wait<0>(); - } - } + typename LayoutB::Index row = + ((blockIdx.x * problem_chunk) / grid_shape_.column()) * Mma::Shape::kM; + if (column) { + shared_storage.reduced_store.initSmem(params.output_op); + shared_storage.rownorm_store.initSmem(params.ptr_C, problem_size.m(), row, sizeof(ElementC)); } // Outer 'persistent' loop to iterate over tiles diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h index b223043b93..244e723ae2 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h @@ -205,6 +205,20 @@ class PredicatedTileIteratorNormVecSmem { Element* data() { return storage.data(); } SharedStorage() {} + + CUTLASS_DEVICE + void initSmem(void *pointer, const Index &num_rows, const Index &tb_row_offset, const LongIndex &stride) { + Element* shared_elem_arr = data(); + uint8_t* first_tile_byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(tb_row_offset) * LongIndex(stride); + const auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + + for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { + bool guard = (tb_row_offset + row) < num_rows; + cutlass::arch::cp_async(shared_elem_arr + row, gmem_ptr + row, guard); + cutlass::arch::cp_async_wait<0>(); + } + } }; private: @@ -305,17 +319,7 @@ class PredicatedTileIteratorNormVecSmem { } if (threadblock_offset.column() == 0) { - Element* shared_elem_arr = shared_storage_.data(); - uint8_t* first_tile_byte_pointer_ = - reinterpret_cast(pointer) + - LongIndex(threadblock_offset.row()) * LongIndex(params_.stride); - auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - bool guard = (threadblock_offset.row() + row) < extent_row_; - cutlass::arch::cp_async(shared_elem_arr + row, gmem_ptr + row, guard); - cutlass::arch::cp_async_wait<0>(); - } + shared_storage_.initSmem(pointer, extent_row_, threadblock_offset.row(), params_.stride); } // Initialize internal state counter diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index 5ce7edf618..d94489e759 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -229,6 +229,16 @@ class PredicatedTileIteratorReducedVec { Element* data() { return storage.data(); } SharedStorage() {} + + CUTLASS_DEVICE + void initSmem(EpilogueOpParams const& user_params) { + Element* shared_elem_arr = data(); + constexpr auto maxVal = std::numeric_limits::max(); + + for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { + user_params.red_op_.init(&shared_elem_arr[row], maxVal); + } + } }; template ::max(); - - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - user_params.red_op_.init(&shared_elem_arr[row], maxVal); - } + shared_storage_.initSmem(user_params); } // Null pointer performs no accesses From 71b91e93c69e21821d8c8cbf75341da272480a22 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 12 May 2023 03:04:05 -0700 Subject: [PATCH 42/48] combine the optimal path and non-optimal path gmem writes --- .../predicated_tile_iterator_reduced_vec.h | 29 ++++++++----------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index d94489e759..70d4d80c49 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -434,10 +434,10 @@ class PredicatedTileIteratorReducedVec { EpilogueOpParams const& user_params = params_.user_param; auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); Element* shared_elem_arr = shared_storage_.data(); - + const uint32_t mutex_id = (block_start_row_first_tile_ / total_rows); + bool useGmemMutex = (gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows)); // If this is not optimal grid size perform mutex based gmem reduce. - if ((gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows))) { - const auto mutex_id = (block_start_row_first_tile_ / total_rows); + if (useGmemMutex) { // single lock per block for multiple rows if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { // acquire mutex lock. @@ -447,28 +447,23 @@ class PredicatedTileIteratorReducedVec { if (ns < 256) { ns *= 2; } } } + } - __syncthreads(); - - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - if (block_start_row_first_tile_ + row < extent_row_) { - user_params.red_op_( - block_start_row_first_tile_ + row, &gmem_ptr[row], shared_elem_arr[row]); - } + __syncthreads(); + for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { + if (block_start_row_first_tile_ + row < extent_row_) { + user_params.red_op_( + block_start_row_first_tile_ + row, &gmem_ptr[row], shared_elem_arr[row]); } + } + + if (useGmemMutex) { __threadfence(); __syncthreads(); if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { // release mutex lock. atomicExch(user_params.mutexes_ + mutex_id, 0); } - } else { - __syncthreads(); - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - if (block_start_row_first_tile_ + row < extent_row_) { - gmem_ptr[row] = shared_elem_arr[row]; - } - } } } } From ad2ce755d6e266d59ee1443f829021723a78e9bc Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 12 May 2023 04:52:25 -0700 Subject: [PATCH 43/48] use the new dispatch mechanism to select the appropriate kernel at runtime --- .../raft/distance/detail/fused_l2_nn.cuh | 96 +++++++++++-------- 1 file changed, 54 insertions(+), 42 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 94c6a3a3bd..19fa6fee2d 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -25,6 +25,7 @@ #include // PairwiseDistances #include // Policy #include // raft::ceildiv, raft::shfl +#include // raft::util::arch::SM_* namespace raft { namespace distance { @@ -153,6 +154,8 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, OpT distance_op, FinalLambda fin_op) { +// compile only if below non-ampere arch. +#if __CUDA_ARCH__ < 800 extern __shared__ char smem[]; typedef KeyValuePair KVPair; @@ -248,6 +251,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, fin_op, rowEpilog_lambda); obj.run(); +#endif } // cg::reduce functor for FusedL2NN used in its cutlass version @@ -304,9 +308,31 @@ void fusedL2NNImpl(OutT* min, RAFT_CUDA_TRY(cudaGetLastError()); } - const auto deviceVersion = getComputeCapability(); - - if (deviceVersion.first >= 8) { + namespace arch = raft::util::arch; + using AccT = DataT; + ops::l2_exp_distance_op distance_op{sqrt}; + + raft::identity_op fin_op{}; + + auto kernel = fusedL2NNkernel; + + // Get pointer to SM60 kernel to determine the runtime architecture of the + // current system. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. See: + // https://github.com/NVIDIA/cub/issues/545 + void* kernel_ptr = reinterpret_cast(kernel); + auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr); + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. using L2Op = raft::distance::detail::ops::l2_exp_cutlass_op; using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; kvp_cg_min_reduce_op_ cg_reduce_op; @@ -316,47 +342,33 @@ void fusedL2NNImpl(OutT* min, lda = k, ldb = k, ldd = n; cutlassFusedDistanceNN(x, - y, - xn, - yn, - m, - n, - k, - lda, - ldb, - ldd, - min, - workspace, - cg_reduce_op, - L2_dist_op, - redOp, - pairRedOp, - stream); + DataT, + OutT, + IdxT, + P::Veclen, + kvp_cg_min_reduce_op_, + L2Op, + ReduceOpT, + KVPReduceOpT>(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + min, + workspace, + cg_reduce_op, + L2_dist_op, + redOp, + pairRedOp, + stream); } else { + // If device less than SM_80, use fp32 SIMT kernel. constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - - using AccT = DataT; - ops::l2_exp_distance_op distance_op{sqrt}; - - raft::identity_op fin_op{}; - - auto kernel = fusedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); kernel<<>>( From 6b87cc9fd6f40b047e9c50d08393ad62f3574aa1 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 12 May 2023 06:13:01 -0700 Subject: [PATCH 44/48] fix comments and small cleanup --- cpp/include/raft/distance/detail/fused_distance_nn/gemm.h | 8 ++++---- .../distance/detail/fused_distance_nn/persistent_gemm.h | 5 ++--- .../predicated_tile_iterator_normvec_smem.h | 5 ++--- cpp/include/raft/distance/detail/fused_l2_nn.cuh | 6 ++++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h index bdb18c4af2..5ed39864ea 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h @@ -100,16 +100,16 @@ struct FusedDistanceNNGemm { // this shape has high occupancy but less perf // this is less performant but this shape has *no* register spills // for any veclens(1, 2, 4) - // using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; + //using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; /// Warp-level tile size (concept: GemmShape) // This code section describes tile size a warp will compute // <- warp tile M = 64, N = 64, K = 16 - using WarpShape = - cutlass::gemm::GemmShape<32, 64, 16>; // this is more performant for grouped GEMM + // this is more performant for veclen 2,4. + using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; // this shape has high occupancy but less perf used for 32x128x16 - // using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; + //using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; /// Warp-level tile size (concept: GemmShape) // This code section describes the size of MMA op diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h index e48c57be99..5a32978775 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h @@ -381,10 +381,9 @@ struct FusedDistanceNNPersistent { using LayoutC = typename Epilogue::OutputTileIterator::Layout; const GemmCoord& problem_size = params.problem_size; - const uint32_t problem_chunk = - (tile_count(grid_shape(problem_size)) - 1 + gridDim.x) / gridDim.x; + const auto grid_shape_ = grid_shape(problem_size); + const uint32_t problem_chunk = (tile_count(grid_shape_) - 1 + gridDim.x) / gridDim.x; const uint32_t problem_chunk_end = blockIdx.x * problem_chunk + problem_chunk; - const auto grid_shape_ = grid_shape(problem_size); typename LayoutB::Index column = ((blockIdx.x * problem_chunk) % grid_shape_.column()) * Mma::Shape::kN; diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h index 244e723ae2..7a35aa328f 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h @@ -184,14 +184,14 @@ class PredicatedTileIteratorNormVecSmem { }; /// Shared storage allocation needed by the predicated tile - // iterator for storing rowNorm chunk of di. + // iterator for storing rowNorm chunk. struct SharedStorage { // // Type definitions // using Shape = MatrixShape; - /// Shape of the shared memory allocation for the reduced values store + /// Shape of the shared memory allocation using StorageShape = MatrixShape; // @@ -237,7 +237,6 @@ class PredicatedTileIteratorNormVecSmem { /// Extent of the matrix tile in rows Index extent_row_; - // Index block_start_row_first_tile_; /// Extent of the matrix tile in rows Index extent_column_; diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 19fa6fee2d..4d0b6569ca 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -254,8 +254,10 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, #endif } -// cg::reduce functor for FusedL2NN used in its cutlass version +// cg::reduce functor for FusedDistanceNN used in its cutlass version // to output the min distance value & key(loc id). +// This is used in fused_distance_nn/predicated_tile_iterator_reduced_vec.h +// store_with_byte_offset() passed to cg::reduce() & select_reduce. template struct kvp_cg_min_reduce_op { typedef typename raft::KeyValuePair KVP; @@ -323,7 +325,7 @@ void fusedL2NNImpl(OutT* min, decltype(distance_op), decltype(fin_op)>; - // Get pointer to SM60 kernel to determine the runtime architecture of the + // Get pointer to fp32 SIMT kernel to determine the runtime architecture of the // current system. Other methods to determine the architecture (that do not // require a pointer) can be error prone. See: // https://github.com/NVIDIA/cub/issues/545 From 383f5bd751864ae0f76b85143716fb9c64818ef5 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 12 May 2023 06:20:28 -0700 Subject: [PATCH 45/48] add comment about persistent_gemm.h mapping to its cutlass version --- .../distance/detail/fused_distance_nn/persistent_gemm.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h index 5a32978775..9dd9e06a61 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h @@ -46,6 +46,12 @@ /*! \file \brief Problem visitor for grouped GEMMs +This file contains heavily customized version of GemmGrouped from CUTLASS 2.10.0 +(https://github.com/NVIDIA/cutlass/blob/v2.10.0/include/cutlass/gemm/kernel/gemm_grouped.h) + +Changes: +- adds support for only single problem size to be launched persistently + where each threablock processes more than one tile of the same problem. */ #pragma once From 3f4ceffbf12c94114f5d92aa86f70bfb952d8d44 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 12 May 2023 06:23:32 -0700 Subject: [PATCH 46/48] fix all formatting issues --- .../distance/detail/fused_distance_nn/gemm.h | 26 ++++----- .../fused_distance_nn/persistent_gemm.h | 22 ++++---- .../predicated_tile_iterator_normvec_smem.h | 12 +++-- .../predicated_tile_iterator_reduced_vec.h | 9 ++-- .../raft/distance/detail/fused_l2_nn.cuh | 54 +++++++++---------- 5 files changed, 64 insertions(+), 59 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h index 5ed39864ea..3da8b3ee3d 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h @@ -63,11 +63,11 @@ namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// /* -* This configuration is used for float inputs with veclen(kAlignmentA/B) = 2 or 4, -* ideal threadblock tile shape is 32x256x16 for such cases as there is no -* registers spills for it. -* -*/ + * This configuration is used for float inputs with veclen(kAlignmentA/B) = 2 or 4, + * ideal threadblock tile shape is 32x256x16 for such cases as there is no + * registers spills for it. + * + */ template < /// Element type for A matrix operand typename ElementA_, @@ -100,16 +100,16 @@ struct FusedDistanceNNGemm { // this shape has high occupancy but less perf // this is less performant but this shape has *no* register spills // for any veclens(1, 2, 4) - //using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; + // using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; /// Warp-level tile size (concept: GemmShape) // This code section describes tile size a warp will compute // <- warp tile M = 64, N = 64, K = 16 // this is more performant for veclen 2,4. - using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; // this shape has high occupancy but less perf used for 32x128x16 - //using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; + // using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; /// Warp-level tile size (concept: GemmShape) // This code section describes the size of MMA op @@ -186,11 +186,11 @@ struct FusedDistanceNNGemm { }; /* -* This configuration is used for float inputs with veclen(kAlignmentA/B) = 1, -* ideal threadblock tile shape is 32x128x16 for such cases as there is no -* registers spills for it. -* -*/ + * This configuration is used for float inputs with veclen(kAlignmentA/B) = 1, + * ideal threadblock tile shape is 32x128x16 for such cases as there is no + * registers spills for it. + * + */ template < /// Element type for C and D matrix operands typename ElementC_, diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h index 9dd9e06a61..3a8d6c8655 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h @@ -50,7 +50,7 @@ This file contains heavily customized version of GemmGrouped from CUTLASS 2.10.0 (https://github.com/NVIDIA/cutlass/blob/v2.10.0/include/cutlass/gemm/kernel/gemm_grouped.h) Changes: -- adds support for only single problem size to be launched persistently +- adds support for only single problem size to be launched persistently where each threablock processes more than one tile of the same problem. */ @@ -379,16 +379,16 @@ struct FusedDistanceNNPersistent { // These types shadow the type-level definitions and support the ability to implement // a 'transposed' GEMM that computes the transposed problems. // - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; - - const GemmCoord& problem_size = params.problem_size; - const auto grid_shape_ = grid_shape(problem_size); - const uint32_t problem_chunk = (tile_count(grid_shape_) - 1 + gridDim.x) / gridDim.x; + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + const GemmCoord& problem_size = params.problem_size; + const auto grid_shape_ = grid_shape(problem_size); + const uint32_t problem_chunk = (tile_count(grid_shape_) - 1 + gridDim.x) / gridDim.x; const uint32_t problem_chunk_end = blockIdx.x * problem_chunk + problem_chunk; typename LayoutB::Index column = ((blockIdx.x * problem_chunk) % grid_shape_.column()) * Mma::Shape::kN; diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h index 7a35aa328f..c35a64f105 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h @@ -207,10 +207,14 @@ class PredicatedTileIteratorNormVecSmem { SharedStorage() {} CUTLASS_DEVICE - void initSmem(void *pointer, const Index &num_rows, const Index &tb_row_offset, const LongIndex &stride) { - Element* shared_elem_arr = data(); - uint8_t* first_tile_byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(tb_row_offset) * LongIndex(stride); + void initSmem(void* pointer, + const Index& num_rows, + const Index& tb_row_offset, + const LongIndex& stride) + { + Element* shared_elem_arr = data(); + uint8_t* first_tile_byte_pointer_ = + reinterpret_cast(pointer) + LongIndex(tb_row_offset) * LongIndex(stride); const auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index 70d4d80c49..dc224c5c96 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -231,9 +231,10 @@ class PredicatedTileIteratorReducedVec { SharedStorage() {} CUTLASS_DEVICE - void initSmem(EpilogueOpParams const& user_params) { - Element* shared_elem_arr = data(); - constexpr auto maxVal = std::numeric_limits::max(); + void initSmem(EpilogueOpParams const& user_params) + { + Element* shared_elem_arr = data(); + constexpr auto maxVal = std::numeric_limits::max(); for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { user_params.red_op_.init(&shared_elem_arr[row], maxVal); @@ -434,7 +435,7 @@ class PredicatedTileIteratorReducedVec { EpilogueOpParams const& user_params = params_.user_param; auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); Element* shared_elem_arr = shared_storage_.data(); - const uint32_t mutex_id = (block_start_row_first_tile_ / total_rows); + const uint32_t mutex_id = (block_start_row_first_tile_ / total_rows); bool useGmemMutex = (gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows)); // If this is not optimal grid size perform mutex based gmem reduce. if (useGmemMutex) { diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 4d0b6569ca..2ff8fa7f1c 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -24,8 +24,8 @@ #include #include // PairwiseDistances #include // Policy -#include // raft::ceildiv, raft::shfl #include // raft::util::arch::SM_* +#include // raft::ceildiv, raft::shfl namespace raft { namespace distance { @@ -311,7 +311,7 @@ void fusedL2NNImpl(OutT* min, } namespace arch = raft::util::arch; - using AccT = DataT; + using AccT = DataT; ops::l2_exp_distance_op distance_op{sqrt}; raft::identity_op fin_op{}; @@ -344,34 +344,34 @@ void fusedL2NNImpl(OutT* min, lda = k, ldb = k, ldd = n; cutlassFusedDistanceNN(x, - y, - xn, - yn, - m, - n, - k, - lda, - ldb, - ldd, - min, - workspace, - cg_reduce_op, - L2_dist_op, - redOp, - pairRedOp, - stream); + DataT, + OutT, + IdxT, + P::Veclen, + kvp_cg_min_reduce_op_, + L2Op, + ReduceOpT, + KVPReduceOpT>(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + min, + workspace, + cg_reduce_op, + L2_dist_op, + redOp, + pairRedOp, + stream); } else { // If device less than SM_80, use fp32 SIMT kernel. constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); + dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); kernel<<>>( min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); From 5f8f33bdc81ea81af883086215092757c9ec5a33 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Tue, 16 May 2023 09:11:36 -0700 Subject: [PATCH 47/48] remove dead/commented code from epilogue broadcast header --- .../detail/fused_distance_nn/custom_epilogue_with_broadcast.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h index 585cd14cd5..c2295ecd55 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h @@ -357,9 +357,6 @@ class EpilogueWithBroadcastCustom : public EpilogueBase Date: Tue, 16 May 2023 09:23:21 -0700 Subject: [PATCH 48/48] fix formatting --- .../detail/fused_distance_nn/custom_epilogue_with_broadcast.h | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h index c2295ecd55..10827a8778 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h @@ -357,7 +357,6 @@ class EpilogueWithBroadcastCustom : public EpilogueBase