From a1d1fd68b7c77d45ef476cdc6c5c1c465d9b8c46 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade <36705640+mdoijade@users.noreply.github.com> Date: Wed, 17 May 2023 03:34:50 +0530 Subject: [PATCH] Fused L2 1-NN based on cutlass 3xTF32 / DMMA (#1118) -- 3xTF32 & DMMA CUTLASS-based persistent FusedL2NN kernel, loosely based on grouped GEMM but customized for a single problem size. -- As the value of `k` increases, the performance benefit of this implementation grows: up to 1.3x for k == 64, up to 1.53x for k == 128, and up to 1.67x for k == 256. -- For all sizes of `k`, this kernel outperforms the previous implementation. -- Attaching the FusedL2NN benchmark results comparing the previous implementation with this CUTLASS version. Authors: - Mahesh Doijade (https://github.com/mdoijade) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/1118 --- cpp/cmake/thirdparty/get_cutlass.cmake | 4 +- .../custom_epilogue_with_broadcast.h | 671 ++++++++++++++++++ .../detail/fused_distance_nn/cutlass_base.cuh | 161 +++++ .../detail/fused_distance_nn/epilogue.cuh | 136 ++++ .../epilogue_elementwise.cuh | 216 ++++++ .../distance/detail/fused_distance_nn/gemm.h | 410 +++++++++++ .../fused_distance_nn/persistent_gemm.h | 515 ++++++++++++++ .../predicated_tile_iterator_normvec_smem.h | 448 ++++++++++++ .../predicated_tile_iterator_reduced_vec.h | 626 ++++++++++++++++ .../raft/distance/detail/fused_l2_nn.cuh | 110 ++- .../detail/pairwise_distance_cutlass_base.cuh | 23 +- .../detail/predicated_tile_iterator_normvec.h | 14 +- .../neighbors/detail/connect_components.cuh | 20 +- cpp/include/raft/util/cutlass_utils.cuh | 53 ++ cpp/test/distance/fused_l2_nn.cu | 6 +- 15 files changed, 3369 insertions(+), 44 deletions(-) create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/gemm.h create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h create mode 100644 cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h create mode 100644 cpp/include/raft/util/cutlass_utils.cuh diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake index cb809de445..853fd7c52f 100644 --- a/cpp/cmake/thirdparty/get_cutlass.cmake +++ b/cpp/cmake/thirdparty/get_cutlass.cmake @@ -78,7 +78,7 @@ function(find_and_configure_cutlass) endfunction() if(NOT RAFT_CUTLASS_GIT_TAG) - set(RAFT_CUTLASS_GIT_TAG v2.9.1) + set(RAFT_CUTLASS_GIT_TAG v2.10.0) endif() if(NOT RAFT_CUTLASS_GIT_REPOSITORY) @@ -86,5 +86,5 @@ if(NOT RAFT_CUTLASS_GIT_REPOSITORY) endif() find_and_configure_cutlass( - VERSION 2.9.1 REPOSITORY ${RAFT_CUTLASS_GIT_REPOSITORY} PINNED_TAG ${RAFT_CUTLASS_GIT_TAG} + VERSION 2.10.0 REPOSITORY ${RAFT_CUTLASS_GIT_REPOSITORY} PINNED_TAG ${RAFT_CUTLASS_GIT_TAG} ) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h
new file mode 100644 index 0000000000..10827a8778 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h @@ -0,0 +1,671 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
+ +This file contains a customized version of EpilogueWithBroadcast from CUTLASS 2.9.1 +(https://github.com/NVIDIA/cutlass/blob/v2.9.1/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h) + +Changes: +- customized the compute_source_needed_() and apply_output_operator_() to suit the needs of per row +reduction +*/ + +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#include +#else +#include +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This base class is meant to define the concept required of the +/// EpilogueWithBroadcast::OutputOp +template +struct EpilogueWithBroadcastOpBaseCustom { + using ElementOutput = ElementC_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementZ = ElementZ_; + using ElementT = ElementT_; + static int const kElementsPerAccess = ElementsPerAccess; + + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentC = Array; + using FragmentZ = Array; + using FragmentT = Array; + + /// If true, the 'Z' tensor is stored + static bool const kStoreZ = StoreZ; + + /// If true, the 'T' tensor is stored + static bool const kStoreT = StoreT; + + /// Parameters structure - required + struct Params {}; + + // + // Methods + // + + /// Constructor from Params + EpilogueWithBroadcastOpBaseCustom(Params const& params_) {} + + /// Determine if the source is needed. May return false if + bool is_source_needed() const { return true; } + + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + /// Applies the operation when is_source_needed() is true + CUTLASS_HOST_DEVICE + void operator()(FragmentZ& frag_Z, + FragmentT& frag_T, + FragmentAccumulator const& AB, + FragmentC const& frag_C, + FragmentCompute const& V) const + { + } + + /// Applies the operation when is_source_needed() is false + CUTLASS_HOST_DEVICE + void operator()(FragmentZ& frag_Z, + FragmentT& frag_T, + FragmentAccumulator const& AB, + FragmentCompute const& V) const + { + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator with bias vector broadcast over columns. 
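// A minimal sketch of the "per row reduction" that the customized compute_source_needed_() and
// apply_output_operator_() support: each row of the output distance tile collapses to a single
// (column index, distance) pair. This illustration is not part of the patch; KeyValuePair and
// reduce_row are hypothetical names, and the real reduction happens fragment-by-fragment inside
// the epilogue defined below.
struct KeyValuePair {
  int key;     // column index of the current nearest neighbor
  float value; // squared L2 distance to that neighbor
};

inline KeyValuePair reduce_row(const float* row, int n)
{
  KeyValuePair best{0, row[0]};
  for (int j = 1; j < n; ++j) {
    if (row[j] < best.value) { best = KeyValuePair{j, row[j]}; }
  }
  return best;
}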
+/// +/// Computes the following: +/// +/// +/// Z, T = OutputOp(AB, C, Broadcast) +/// +/// if (ElementwiseOp::kStoreZ) { +/// store(converted_u); +/// } +/// +/// if (ElementwiseOp::kStoreT) { +/// store(v); +/// } +/// +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors (z) + typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands (t) + typename ElementVector_, ///< Pointer to broadcast vector + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM + typename OutputOp_, ///< Output operator - concept is EpilogueWithBroadcastOp + typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: + ///< MatrixShape) + int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large + (!IsEpilogueFunctorHeavy::value)> +class EpilogueWithBroadcastCustom : public EpilogueBase { + public: + using Base = EpilogueBase; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using TensorTileIterator = TensorTileIterator_; + using ElementVector = ElementVector_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Compute data type produced by the output op + using ElementCompute = typename OutputOp::ElementCompute; + + /// Compute fragment + using FragmentCompute = Array; + + /// Thread map used by output tile iterators + using ThreadMap = typename OutputTileIterator::ThreadMap; + + /// Fragment object used to store the broadcast values + using BroadcastFragment = + Array; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Data type of additional tensor + using ElementTensor = typename TensorTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = + Array; + + /// Array type used by output functor + using AccumulatorAccessType = + Array; + + /// Array type used by output functor + using ComputeAccessType = Array; + + /// Tensor access type + using TensorAccessType = Array; + + /// Number of warps + using WarpCount = typename 
Base::WarpCount; + + /// Shared memory allocation from epilogue base class + using BaseSharedStorage = typename Base::SharedStorage; + + static int constexpr kSmemTiles = + Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; + static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; + + /// Used for the broadcast + struct BroadcastDetail { + /// Number of threads per warp + static int const kWarpSize = 32; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + /// Number of distinct scalar column indices handled by each thread + static int const kColumnsPerThread = + ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; + + /// Number of distinct scalar row indices handled by each thread + static int const kRowsPerThread = + ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; + + /// Number of threads per threadblock + static int const kThreadCount = kWarpSize * WarpCount::kCount; + + /// Number of distinct threads per row of output tile + static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread); + + /// Number of distinct threads which must be reduced during the final reduction phase within the + /// threadblock. + static int const kThreadRows = kThreadCount / kThreadsPerRow; + + /// I'm not sure what I meant here. + static int const kThreadAccessesPerRow = + const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); + + /// Shape of the shared memory allocation for the epilogue + using StorageShape = MatrixShape; + + /// Debug printing + CUTLASS_DEVICE + static void print() + { +#if 0 + printf("BroadcastDetail {\n"); + printf( + " kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n" + "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n", + kColumnsPerThread, + kRowsPerThread, + kThreadCount, + kThreadsPerRow, + kThreadRows, + kThreadAccessesPerRow, + StorageShape::kRow, + StorageShape::kColumn, + StorageShape::kCount + ); + printf("};\n"); +#endif + } + }; + + /// Shared storage structure (shadows base) with additional SMEM buffer for reduction + struct SharedStorage { + union { + BaseSharedStorage base; + }; + + CUTLASS_HOST_DEVICE + SharedStorage() {} + }; + + public: + static_assert(SharedLoadIterator::Fragment::kElements == TensorTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert(OutputTileIterator::kElementsPerAccess, + "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), + "Divisibility"); + + private: + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + /// Thread index within the threadblock + int thread_idx_; + + public: + /// Constructor + CUTLASS_DEVICE + EpilogueWithBroadcastCustom(SharedStorage& shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ) + : Base(shared_storage.base, thread_idx, warp_idx, lane_idx), + shared_load_iterator_(shared_storage.base.reference(), thread_idx), + thread_idx_(thread_idx) + { + } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const& output_op, ///< Output operator + ElementVector const* broadcast_ptr, ///< Broadcast vector + 
AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix + TensorTileIterator + tensor_iterator, ///< Threadblock tile iterator for additional tensor operand + MatrixCoord const& + problem_size = ///< Problem size needed to guard against out-of-bounds accesses + MatrixCoord(Shape::kM, Shape::kN), + MatrixCoord const& + threadblock_offset = ///< Threadblock's initial offset within the problem size space + MatrixCoord()) + { + BroadcastFragment broadcast_fragment; + + load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset); + + compute_source_needed_( + output_op, broadcast_fragment, accumulators, source_iterator, tensor_iterator); + } + + private: + CUTLASS_DEVICE + void load_broadcast_fragment_( + BroadcastFragment& + broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns + ElementVector const* broadcast_ptr, ///< Broadcast vector + MatrixCoord const& + problem_size, ///< Problem size needed to guard against out-of-bounds accesses + MatrixCoord const& + threadblock_offset ///< Threadblock's initial offset within the problem size space + ) + { + broadcast_fragment.clear(); + + // If no pointer is supplied, set with all zeros and avoid memory accesses + if (!broadcast_ptr) { return; } + + int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column(); + + int thread_column_idx = threadblock_offset.column() + thread_initial_column; + broadcast_ptr += thread_initial_column; + + NumericArrayConverter + converter; + using AccessType = AlignedArray; + using ComputeFragmentType = Array; + + ComputeFragmentType* frag_ptr = reinterpret_cast(&broadcast_fragment); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) { + AccessType loaded; + + loaded.clear(); + + if (thread_column_idx < problem_size.column()) { + loaded = *reinterpret_cast(broadcast_ptr); + } + + ComputeFragmentType cvt = converter(loaded); + frag_ptr[j] = cvt; + + thread_column_idx += ThreadMap::Delta::kColumn; + broadcast_ptr += ThreadMap::Delta::kColumn; + } + } + + template + struct acc2smem_source_not_needed; + + template + struct acc2smem_source_not_needed> { + template + CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + warp_tile_iterator.store(accum_fragment); + if (p < Base::kFragmentsPerIteration - 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); + } + } + + if (Base::kFragmentsPerIteration > 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset * + (1 - Base::kFragmentsPerIteration)); + } + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) + { + int dummy[] = { + (pos == (Seq * Base::kFragmentsPerIteration)) && + (helper(iterator_begin, warp_tile_iterator), 0)...}; + + CUTLASS_UNUSED(dummy[0]); + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + OutputOp const& output_op, ///< Output operator + BroadcastFragment const& + 
broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns + OutputTileIterator destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile + TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand + ) + { + } + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) + { + int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const& output_op, ///< Output operator + BroadcastFragment const& + broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns + AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile + OutputTileIterator + source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand + ) + { + typename OutputTileIterator::Fragment source_fragment; + source_fragment.clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll(IterationsUnroll ? 
OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + // + // Convert and store fragment + // + + //__syncthreads(); + + acc2smem_source_needed>::push( + iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // + // Apply output operation + // + + typename TensorTileIterator::Fragment frag_T; + + // + // Load the source + // + + source_iterator.load(source_fragment); + ++source_iterator; + + apply_output_operator_( + frag_T, output_op, aligned_accum_fragment[0], source_fragment, broadcast_fragment); + + // + // Conditionally store fragments + // + if (OutputOp::kStoreT) { + tensor_iterator.store(frag_T); + ++tensor_iterator; + } + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_(typename TensorTileIterator::Fragment& frag_T, + OutputOp const& output_op, + typename SharedLoadIterator::Fragment const& frag_AB, + typename OutputTileIterator::Fragment const& frag_C, + BroadcastFragment const& frag_Broadcast) + { + using AccessTypeT = Array; + using AccessTypeBroadcast = Array; + + AccessTypeT* frag_T_ptr = reinterpret_cast(&frag_T); + + AccumulatorAccessType const* frag_AB_ptr = + reinterpret_cast(&frag_AB); + + OutputAccessType const* frag_C_ptr = reinterpret_cast(&frag_C); + + AccessTypeBroadcast const* frag_Broadcast_ptr = + reinterpret_cast(&frag_Broadcast); + + int const kOutputOpIterations = + TensorTileIterator::Fragment::kElements / TensorTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + output_op(frag_T_ptr[i], + frag_AB_ptr[i], + frag_C_ptr[(i / ThreadMap::Iterations::kColumn)], + frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + typename OutputTileIterator::Fragment& frag_Z, + typename TensorTileIterator::Fragment& frag_T, + OutputOp const& output_op, + typename SharedLoadIterator::Fragment const& frag_AB, + BroadcastFragment const& frag_Broadcast) + { + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh new file mode 100644 index 0000000000..a1cf1a9b17 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wtautological-compare" + +// We define CUTLASS_NAMESPACE in case +// RAFT cmake is not used +#ifndef CUTLASS_NAMESPACE +#define cutlass raft_cutlass +#endif + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include // FusedDistanceNNEpilogueElementwise +#include // FusedDistanceNNGemm +#include // getMultiProcessorCount +#include // RAFT_CUTLASS_TRY + +namespace raft { +namespace distance { +namespace detail { + +template +void cutlassFusedDistanceNN(const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + int* mutexes, + CGReduceOpT cg_reduce_op, + DistanceFn dist_op, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + cudaStream_t stream) +{ + using EpilogueOutputOp = cutlass::epilogue::thread::FusedDistanceNNEpilogueElementwise< + DataT, // ElementC_ + AccT, // ElementAccumulator_ + DataT, // ElementCompute_ + AccT, // ElementZ_ + OutT, // ElementT_ + // 128 / cutlass::sizeof_bits::value, + 1, // Elements per access 1 + DistanceFn, + CGReduceOpT, + ReduceOpT, + KVPReduceOpT>; + constexpr int batch_count = 1; + + typename EpilogueOutputOp::Params epilog_op_param( + dist_op, cg_reduce_op, redOp, pairRedOp, mutexes); + + // Number of pipelines you want to use + constexpr int NumStages = 3; + // Alignment + constexpr int Alignment = VecLen; + + // default initialize problem size with row major inputs + auto problem_size = cutlass::gemm::GemmCoord(m, n, k); + + constexpr bool isRowMajor = true; + + using fusedDistanceNNKernel = + typename cutlass::gemm::kernel::FusedDistanceNNGemm::GemmKernel; + + using fusedDistanceNN = cutlass::gemm::device::GemmGrouped; + + int num_blocks_per_sm = fusedDistanceNN::maximum_active_blocks(); + int num_sms = raft::getMultiProcessorCount(); + int full_wave = num_blocks_per_sm * num_sms; + constexpr int mmaShapeM = fusedDistanceNNKernel::Mma::Shape::kM; + constexpr int mmaShapeN = fusedDistanceNNKernel::Mma::Shape::kN; + int columnTiles = (problem_size.n() - 1 + mmaShapeN) / mmaShapeN; + int rowTiles = (problem_size.m() - 1 + mmaShapeM) / mmaShapeM; + int totalTiles = columnTiles * rowTiles; + int thread_blocks = + rowTiles < full_wave ? (totalTiles < full_wave ? totalTiles : full_wave) : rowTiles; + + typename fusedDistanceNN::Arguments arguments{ + problem_size, + batch_count, // num of problems. 
+ thread_blocks, + epilog_op_param, + x, + y, + xn, // C matrix eq vector param, which here is A norm + (DataT*)yn, // this is broadcast vec, which is required to be non-const param + dOutput, // Output distance matrix + (int64_t)lda, // stride A + (int64_t)ldb, // stride B + (int64_t)1, // stride A norm + (int64_t)ldd // stride Output matrix + }; + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = fusedDistanceNN::get_workspace_size(arguments); + // Allocate workspace memory + rmm::device_uvector workspace(workspace_size, stream); + // Instantiate CUTLASS kernel depending on templates + fusedDistanceNN fusedDistanceNN_op; + // Check the problem size is supported or not + RAFT_CUTLASS_TRY(fusedDistanceNN_op.can_implement(arguments)); + // Initialize CUTLASS kernel with arguments and workspace pointer + RAFT_CUTLASS_TRY(fusedDistanceNN_op.initialize(arguments, workspace.data(), stream)); + // Launch initialized CUTLASS kernel + RAFT_CUTLASS_TRY(fusedDistanceNN_op.run(stream)); +} + +}; // namespace detail +}; // namespace distance +}; // namespace raft + +#pragma GCC diagnostic pop diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh new file mode 100644 index 0000000000..8a0bea3469 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + +This is adapted from DefaultEpilogueWithBroadcastTensorOp from CUTLASS 2.9.0 +(https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h#L75) + +This epilogue allows us to load norm buffers using PredicatedTileIteratorNormVec +and EpilogueWithBroadcast used for distances L2/cosine as well as applies user-define elementwise +operation. +-- A norm load is provided PredicatedTileIteratorNormVec +-- B norm load is provided by EpilogueWithBroadcast +-- elementwise operation is provided by OutputOp +*/ + +#pragma once + +#include +#include +#include + +#include + +#include +#include +#include +#include + +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for TensorOps. +template +struct FusedDistanceNNEpilogue { + /// Use defaults related to the existing epilogue + using Base = + DefaultEpilogueTensorOp; + + // + // Stores the result z = (y = GEMM(A, B, C), broadcast) + // + using RowNormTileIterator = cutlass::epilogue::threadblock:: + PredicatedTileIteratorNormVecSmem; + + // + // Additional tensor tile iterator - stores t = Elementwise(z) + // + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorReducedVec< + typename Base::OutputTileThreadMap, + ElementTensor, + LayoutT, + typename OutputOp::Params>; + + /// Define the epilogue + using Epilogue = cutlass::epilogue::threadblock::EpilogueWithBroadcastCustom< + Shape, + WarpMmaTensorOp, + PartitionsK, + RowNormTileIterator, + OutputTileIterator, + ElementVector, + typename Base::AccumulatorFragmentIterator, + typename Base::WarpTileIterator, + typename Base::SharedLoadIterator, + OutputOp, + typename Base::Padding, + Base::kFragmentsPerIteration>; +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh new file mode 100644 index 0000000000..a21f3d60e0 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh @@ -0,0 +1,216 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// +/*! \file + \brief Functor performing distance operations used by epilogues of pairwise distance + * kernels. +* This is adapted from LinearCombinationBiasElementwise from CUTLASS 2.9.0 +* customized for applying elementwise distance formula on accumulated GEMM value +* and applying user-defined operation which can convert distance values to key-value pair. +* . 
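// A minimal sketch of the element-wise distance formula this functor applies, assuming the
// squared-L2 expansion used by fused L2 NN (squared_l2 is an illustrative name, not a RAFT API):
// the GEMM mainloop supplies the dot product x_i . y_j, the C operand supplies the row norm
// ||x_i||^2, and the broadcast vector supplies the column norm ||y_j||^2.
inline float squared_l2(float x_norm, float y_norm, float dot_xy)
{
  // ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 * (x_i . y_j)
  return x_norm + y_norm - 2.0f * dot_xy;
}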
+*/ + +#pragma once + +#include +#include +#include +#include +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This base class is meant to define the concept required of the +/// EpilogueWithBroadcast::OutputOp +template +class FusedDistanceNNEpilogueElementwise { + public: + using ElementOutput = ElementC_; + using ElementC = ElementC_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementZ = ElementZ_; + using ElementT = ElementT_; + static int const kElementsPerAccess = ElementsPerAccess; + static int const kCount = kElementsPerAccess; + + using DistanceOp = DistanceOp_; + using CGReduceOp = CGReduceOp_; + + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentC = Array; + using FragmentZ = Array; + using OutValT = typename CGReduceOp::AccTypeT; + using FragmentT = Array; + + using FragmentOutput = FragmentZ; + + static bool const kIsHeavy = true; // ElementwiseOp::kIsHeavy; + + /// If true, the 'Z' tensor is stored + static bool const kStoreZ = false; // We don't store anything in Z, + + /// If true, the 'T' tensor is stored + static bool const kStoreT = true; // this is our final output storage. + + /// Host-constructable parameters structure + struct Params { + CGReduceOp_ cg_reduce_op; + DistanceOp_ dist_op_; + KVPReduceOpT_ pair_redop_; + ReduceOpT_ red_op_; + int* mutexes_; + using CGReduceT = CGReduceOp_; + // + // Methods + // + CUTLASS_HOST_DEVICE + Params(DistanceOp_ dist_op, + CGReduceOp cg_reduce_op, + ReduceOpT_ red_op, + KVPReduceOpT_ pair_redop, + int* mutexes) + : cg_reduce_op(cg_reduce_op), + dist_op_(dist_op), + pair_redop_(pair_redop), + red_op_(red_op), + mutexes_(mutexes) + { + } + + CUTLASS_HOST_DEVICE + Params() {} + }; + + private: + // + // Data members + // + DistanceOp_ elementwise_op; + KVPReduceOpT_ pair_redop; + + public: + ReduceOpT_ red_op; + + // + // Methods + // + + /// Constructor from Params + CUTLASS_HOST_DEVICE + FusedDistanceNNEpilogueElementwise(Params const& params) + : elementwise_op(params.dist_op_), pair_redop(params.pair_redop_), red_op(params.red_op_) + { + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const + { + // we use for making sure C matrix is used for A mat norm. 
+ return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + /// Applies the operation when is_source_needed() is true + CUTLASS_HOST_DEVICE + void operator()(FragmentT& frag_T, + FragmentAccumulator const& AB, + FragmentC const& frag_C, + FragmentCompute const& V) const + { + FragmentCompute tmp_Accum = + NumericArrayConverter()(AB); + FragmentCompute tmp_C = + NumericArrayConverter()(frag_C); + FragmentCompute result_Z; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kElementsPerAccess; ++i) { + ElementCompute res_Z = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); + frag_T[i] = res_Z; + } + } + + /// Applies the operation when is_source_needed() is false + CUTLASS_HOST_DEVICE + void operator()(FragmentZ& frag_Z, + FragmentT& frag_T, + FragmentAccumulator const& AB, + FragmentCompute const& V) const + { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h new file mode 100644 index 0000000000..3da8b3ee3d --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h @@ -0,0 +1,410 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
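// A rough illustration of the 3xTF32 idea selected below via cutlass::arch::OpMultiplyAddFastF32.
// This is a scalar emulation sketch only, not the tensor-core code path: to_tf32 truncates rather
// than rounds, and the names are illustrative. Each fp32 operand is split into a TF32 "big" part
// and a TF32 remainder, and three TF32 products approximate the fp32 product.
#include <cstring>

inline float to_tf32(float x)
{
  unsigned int u;
  std::memcpy(&u, &x, sizeof(u));
  u &= 0xffffe000u;  // keep TF32's 10 explicit mantissa bits
  float r;
  std::memcpy(&r, &u, sizeof(r));
  return r;
}

inline float mul_3xtf32(float a, float b)
{
  float a_big = to_tf32(a), a_small = to_tf32(a - a_big);
  float b_big = to_tf32(b), b_small = to_tf32(b - b_big);
  // The a_small * b_small term is dropped, hence "3x" TF32.
  return a_big * b_big + a_big * b_small + a_small * b_big;
}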
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include + +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/* + * This configuration is used for float inputs with veclen(kAlignmentA/B) = 2 or 4, + * ideal threadblock tile shape is 32x256x16 for such cases as there is no + * registers spills for it. + * + */ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Number of stages used in the pipelined mainloop + int Stages, + /// data layout row/column major of inputs + bool isRowMajor> +struct FusedDistanceNNGemm { + // This struct is specialized for fp32/3xTF32 + + /// Threadblock-level tile size (concept: GemmShape) + // <- threadblock tile M = 32, N = 256, K = 16 + // this is more performant but note that for veclen = 1 + // this shape has register spills + using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 16>; + + // <- threadblock tile M = 32, N = 128, K = 16 + // this shape has high occupancy but less perf + // this is less performant but this shape has *no* register spills + // for any veclens(1, 2, 4) + // using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; + + /// Warp-level tile size (concept: GemmShape) + // This code section describes tile size a warp will compute + // <- warp tile M = 64, N = 64, K = 16 + // this is more performant for veclen 2,4. + using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; + + // this shape has high occupancy but less perf used for 32x128x16 + // using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; + + /// Warp-level tile size (concept: GemmShape) + // This code section describes the size of MMA op + // <- MMA Op tile M = 16, N = 8, K = 4 + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastF32; + // using Operator = cutlass::arch::OpMultiplyAdd; // this runs only 1xTF32 for float inputs + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using OperatorClass = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using ArchTag = cutlass::arch::Sm80; + + // This code section describes how threadblocks are scheduled on GPU + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + /// data layout for final output matrix. 
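// The std::conditional typedefs that follow (NormXLayout, LayoutA_, LayoutB_) appear with their
// template arguments stripped in this rendering of the patch. A self-contained sketch of the
// intended selection, assuming it mirrors RAFT's pairwise-distance GEMM setup (RowMajorTag,
// ColumnMajorTag, and LayoutSelect are stand-ins, not CUTLASS types): the A operand and the norm
// vector follow the input layout, while the B operand uses the opposite layout because the
// kernel effectively computes x . y^T.
#include <type_traits>

struct RowMajorTag {};
struct ColumnMajorTag {};

template <bool isRowMajor>
struct LayoutSelect {
  using NormXLayout = typename std::conditional<isRowMajor, RowMajorTag, ColumnMajorTag>::type;
  using LayoutA     = typename std::conditional<isRowMajor, RowMajorTag, ColumnMajorTag>::type;
  using LayoutB     = typename std::conditional<isRowMajor, ColumnMajorTag, RowMajorTag>::type;
};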
+ // we keep this same layout even for column major inputs + using LayoutOutput = cutlass::layout::RowMajor; + + typedef typename std::conditional::type NormXLayout; + + typedef typename std:: + conditional::type LayoutA_; + + typedef typename std:: + conditional::type LayoutB_; + + using GemmBase = typename DefaultGemmUniversal::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::FusedDistanceNNEpilogue< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementAccumulator, + typename EpilogueOutputOp::ElementT, + ElementAccumulator, + EpilogueOutputOp, + NormXLayout, + GemmBase::Epilogue::kElementsPerAccess>::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = FusedDistanceNNPersistent; +}; + +/* + * This configuration is used for float inputs with veclen(kAlignmentA/B) = 1, + * ideal threadblock tile shape is 32x128x16 for such cases as there is no + * registers spills for it. + * + */ +template < + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Number of stages used in the pipelined mainloop + int Stages, + /// data layout row/column major of inputs + bool isRowMajor> +struct FusedDistanceNNGemm { + // This struct is specialized for fp32/3xTF32 + using ElementA_ = float; + using ElementB_ = float; + + /// Threadblock-level tile size (concept: GemmShape) + // <- threadblock tile M = 32, N = 128, K = 16 + // this shape has high occupancy and no register spills for veclen = 1. + using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; + + /// Warp-level tile size (concept: GemmShape) + // This code section describes tile size a warp will compute + // <- warp tile M = 32, N = 32, K = 16 + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; + + /// Warp-level tile size (concept: GemmShape) + // This code section describes the size of MMA op + // <- MMA Op tile M = 16, N = 8, K = 4 + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + /// Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAddFastF32; + // using Operator = cutlass::arch::OpMultiplyAdd; // this runs only 1xTF32 for float inputs + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using OperatorClass = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using ArchTag = cutlass::arch::Sm80; + + // This code section describes how threadblocks are scheduled on GPU + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + /// data layout for final output matrix. 
+ // we keep this same layout even for column major inputs + using LayoutOutput = cutlass::layout::RowMajor; + + typedef typename std::conditional::type NormXLayout; + + typedef typename std:: + conditional::type LayoutA_; + + typedef typename std:: + conditional::type LayoutB_; + + using GemmBase = typename DefaultGemmUniversal::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::FusedDistanceNNEpilogue< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementAccumulator, + typename EpilogueOutputOp::ElementT, + ElementAccumulator, + EpilogueOutputOp, + NormXLayout, + GemmBase::Epilogue::kElementsPerAccess>::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = FusedDistanceNNPersistent; +}; + +template < + /// Layout type for A matrix operand + int kAlignmentA, + /// Layout type for B matrix operand + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' + typename EpilogueOutputOp, + /// Number of stages used in the pipelined mainloop + int Stages, + /// data layout row/column major of inputs + bool isRowMajor> +struct FusedDistanceNNGemm { + // Threadblock-level tile size (concept: GemmShape) + // <- threadblock tile M = 64, N = 64, K = 16 + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; + // using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; + /// Warp-level tile size (concept: GemmShape) + // This code section describes tile size a warp will compute + // <- warp tile M = 32, N = 32, K = 16 + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; + // using WarpShape = cutlass::gemm::GemmShape<16, 32, 16>; + /// Warp-level tile size (concept: GemmShape) + // This code section describes the size of MMA op + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + // Operation performed by GEMM + using Operator = cutlass::arch::OpMultiplyAdd; + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU + // SM + using OperatorClass = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using ArchTag = cutlass::arch::Sm80; + + // This code section describes how threadblocks are scheduled on GPU + /// Threadblock-level swizzling operator + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + /// data layout for final output matrix. 
+ // we keep this same layout even for column major inputs + using LayoutOutput = cutlass::layout::RowMajor; + + typedef typename std::conditional::type NormXLayout; + + typedef typename std:: + conditional::type LayoutA_; + + typedef typename std:: + conditional::type LayoutB_; + + using GemmBase = typename DefaultGemmUniversal::GemmKernel; + + // Replace epilogue + using Epilogue = typename cutlass::epilogue::threadblock::FusedDistanceNNEpilogue< + typename GemmBase::Epilogue::Shape, + typename GemmBase::Epilogue::WarpMmaOperator, + GemmBase::Epilogue::kPartitionsK, + ElementC_, + typename EpilogueOutputOp::ElementT, + ElementC_, + EpilogueOutputOp, + NormXLayout, + GemmBase::Epilogue::kElementsPerAccess>::Epilogue; + + // Compose the GEMM kernel + using GemmKernel = FusedDistanceNNPersistent; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass \ No newline at end of file diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h new file mode 100644 index 0000000000..3a8d6c8655 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h @@ -0,0 +1,515 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Problem visitor for grouped GEMMs +This file contains heavily customized version of GemmGrouped from CUTLASS 2.10.0 +(https://github.com/NVIDIA/cutlass/blob/v2.10.0/include/cutlass/gemm/kernel/gemm_grouped.h) + +Changes: +- adds support for only single problem size to be launched persistently + where each threablock processes more than one tile of the same problem. +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FusedDistanceNNPersistent { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = Transposed; + + // Optional transpose + using MapArguments = kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. 
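// A conceptual sketch of the persistence strategy described in the file comment above: the grid
// is sized so that each thread block strides over several output tiles of the single problem.
// This standalone kernel is illustrative only and is not the CUTLASS kernel defined below.
__global__ void persistent_tile_loop(int tiles_m, int tiles_n)
{
  int total_tiles = tiles_m * tiles_n;
  for (int tile = blockIdx.x; tile < total_tiles; tile += gridDim.x) {
    int tile_row = tile / tiles_n;  // threadblock-tile row of the output
    int tile_col = tile % tiles_n;  // threadblock-tile column of the output
    // ... mainloop + epilogue for output tile (tile_row, tile_col) would run here
    (void)tile_row;
    (void)tile_col;
  }
}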
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor = GemmGroupedProblemVisitor; + + // + // Structures + // + + struct temp_problem_visitor { + int problem_count; + + CUTLASS_HOST_DEVICE temp_problem_visitor() : problem_count(0){}; + CUTLASS_HOST_DEVICE temp_problem_visitor(int problem_count_) : problem_count(problem_count_){}; + }; + + /// Argument structure + struct Arguments { + // + // Data members + // + GemmCoord problem_sizes; + temp_problem_visitor problem_visitor; + int problem_count; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + void const* ptr_A; + void const* ptr_B; + void const* ptr_C; + void* ptr_Vector; + void* ptr_Tensor; + + typename LayoutA::Stride::Index lda; + typename LayoutB::Stride::Index ldb; + typename LayoutC::Stride::Index ldc; + typename LayoutC::Stride::Index ldt; + + // Only used by device-level operator + GemmCoord* host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : // problem_count(0), + threadblock_count(0), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_Vector(nullptr), + ptr_Tensor(nullptr), + lda(0), + ldb(0), + ldc(0), + ldt(0), + host_problem_sizes(nullptr) + { + } + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(GemmCoord problem_sizes, + int problem_count, + int threadblock_count, + typename EpilogueOutputOp::Params output_op, + void const* ptr_A, + void const* ptr_B, + void const* ptr_C, + void* ptr_Vector, + void* ptr_Tensor, + typename LayoutA::Stride::Index lda, + typename LayoutB::Stride::Index ldb, + typename LayoutC::Stride::Index ldc, + typename LayoutC::Stride::Index ldt, + GemmCoord* host_problem_sizes = nullptr) + : problem_sizes(problem_sizes), + threadblock_count(threadblock_count), + output_op(output_op), + ptr_A(ptr_A), + ptr_B(ptr_B), + ptr_C(ptr_C), + ptr_Vector(ptr_Vector), + ptr_Tensor(ptr_Tensor), + lda(lda), + ldb(ldb), + ldc(ldc), + ldt(ldt), + host_problem_sizes(host_problem_sizes) + { + problem_visitor.problem_count = problem_count; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + // typename ProblemVisitor::Params problem_visitor; + temp_problem_visitor problem_visitor; + int threadblock_count; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::TensorTileIterator::Params params_Tensor; + + typename EpilogueOutputOp::Params output_op; + + void* ptr_A; + void* ptr_B; + void* ptr_C; + void* ptr_Vector; + void* ptr_Tensor; + + GemmCoord problem_size; + typename LayoutA::Stride::Index lda; + typename LayoutB::Stride::Index ldb; + typename LayoutC::Stride::Index ldc; + typename LayoutC::Stride::Index ldt; + + // + // Methods + // + + 
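+    // Params is the device-side mirror of Arguments: the iterator Params are prebuilt from the
+    // strides, and params_Tensor additionally captures output_op so the reduced-vector tile
+    // iterator can reach the user-supplied reduction functors.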
CUTLASS_HOST_DEVICE + Params() + : params_A(0), + params_B(0), + params_C(0), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_Vector(nullptr), + ptr_Tensor(nullptr), + lda(0), + ldb(0), + ldc(0), + ldt(0) + { + } + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_size(args.problem_sizes), + threadblock_count(args.threadblock_count), + output_op(args.output_op), + params_A(args.lda), + params_B(args.ldb), + params_C(args.ldc), + // Here we pass additional user args via args.output_op + // to the reduction output tile iterator + params_Tensor(args.ldt, args.output_op), + ptr_A(const_cast(args.ptr_A)), + ptr_B(const_cast(args.ptr_B)), + ptr_C(const_cast(args.ptr_C)), + ptr_Vector(args.ptr_Vector), + ptr_Tensor(args.ptr_Tensor), + lda(args.lda), + ldb(args.ldb), + ldc(args.ldc), + ldt(args.ldt) + { + problem_visitor.problem_count = args.problem_visitor.problem_count; + } + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + { + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = const_cast(args.ptr_A); + ptr_B = const_cast(args.ptr_B); + ptr_C = const_cast(args.ptr_C); + ptr_Vector = args.ptr_Vector; + ptr_Tensor = args.ptr_Tensor; + lda = args.lda; + ldb = args.ldb; + ldc = args.ldc; + ldt = args.ldt; + + problem_size = args.problem_sizes; + } + }; + + /// Shared memory storage structure + struct SharedStorage { + union { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + } kernel; + + typename Epilogue::TensorTileIterator::SharedStorage reduced_store; + typename Epilogue::OutputTileIterator::SharedStorage rownorm_store; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + FusedDistanceNNPersistent() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) + { + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) { return Status::kSuccess; } + + static size_t get_extra_workspace_size(Arguments const& args, + cutlass::gemm::GemmCoord const& grid_tiled_shape) + { + return 0; + } + + CUTLASS_DEVICE + static uint32_t tile_count(const cutlass::MatrixCoord& grid) + { + return grid.row() * grid.column(); + } + + /// Get the grid shape + CUTLASS_DEVICE + static cutlass::MatrixCoord grid_shape(const cutlass::gemm::GemmCoord& problem) + { + return cutlass::MatrixCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN)); + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { +#if __CUDA_ARCH__ >= 800 + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. 
+ // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + const GemmCoord& problem_size = params.problem_size; + const auto grid_shape_ = grid_shape(problem_size); + const uint32_t problem_chunk = (tile_count(grid_shape_) - 1 + gridDim.x) / gridDim.x; + const uint32_t problem_chunk_end = blockIdx.x * problem_chunk + problem_chunk; + typename LayoutB::Index column = + ((blockIdx.x * problem_chunk) % grid_shape_.column()) * Mma::Shape::kN; + + typename LayoutB::Index row = + ((blockIdx.x * problem_chunk) / grid_shape_.column()) * Mma::Shape::kM; + if (column) { + shared_storage.reduced_store.initSmem(params.output_op); + shared_storage.rownorm_store.initSmem(params.ptr_C, problem_size.m(), row, sizeof(ElementC)); + } + + // Outer 'persistent' loop to iterate over tiles + for (uint32_t tile_idx = blockIdx.x * problem_chunk; tile_idx < problem_chunk_end; tile_idx++) { + const auto grid_shape_ = grid_shape(problem_size); + cutlass::MatrixCoord threadblock_offset( + int(tile_idx / grid_shape_.column()) * Mma::Shape::kM, + int(tile_idx % grid_shape_.column()) * Mma::Shape::kN); + + const bool isNextTile = ((tile_idx + 1) < problem_chunk_end); + const bool doesRowChange = + ((threadblock_offset.column() + Mma::Shape::kN) >= problem_size.n()); + const bool do_gmem_reduce = (doesRowChange || !isNextTile) ? true : false; + + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{threadblock_offset.row(), 0}; + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.column()}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, ptr_B, {problem_size.k(), problem_size.n()}, thread_idx, tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous tile. + //__syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementC* ptr_C = static_cast(params.ptr_C); + typename Epilogue::ElementTensor* ptr_Tensor = + static_cast(params.ptr_Tensor); + + // Define the reduction output pointer and move to the appropriate place + typename Epilogue::ElementVector* ptr_Vector = + static_cast(params.ptr_Vector); + + // Tile iterator loading from source tensor. 
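+      // iterator_rownorm reads the row norms (ptr_C) through the shared-memory buffer filled by
+      // initSmem() above, so every column tile of a row re-uses the same staged values, while
+      // tensor_iterator accumulates per-row (key, distance) minima and flushes them to global
+      // memory only when do_gmem_reduce is set (last column tile of a row, or last tile of the
+      // block's chunk).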
+ typename Epilogue::OutputTileIterator iterator_rownorm(shared_storage.rownorm_store, + params.params_C, + ptr_C, + problem_size.mn(), + thread_idx, + threadblock_offset); + + // Additional tensor to load from + typename Epilogue::TensorTileIterator tensor_iterator(shared_storage.reduced_store, + params.params_Tensor, + // Only the final block outputs Tensor + ptr_Tensor, + problem_size.mn(), + thread_idx, + do_gmem_reduce, + threadblock_offset); + + Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + // Move to appropriate location for this output tile + if (ptr_Vector) { ptr_Vector += threadblock_offset.column(); } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, + ptr_Vector, + // iterator_D, + accumulators, + iterator_rownorm, + tensor_iterator, + problem_size.mn(), + threadblock_offset); + } +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h new file mode 100644 index 0000000000..c35a64f105 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h @@ -0,0 +1,448 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + +This file contains a customized version of PredicatedTileIterator from CUTLASS 2.9.0 +(https://github.com/NVIDIA/cutlass/blob/v2.9.0/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L75) + +Changes: +- added `Layout_` template param +- Only the row index is used to load the data in load_with_byte_offset(). + This way the same normalization data is used across all columns in a row. + +*/ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +template +class PredicatedTileIteratorNormVecSmem { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = Layout_; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static int const total_rows = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * + ThreadMap::Count::kTile * ThreadMap::Delta::kRow; + + static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); + static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); + static_assert(ThreadMap::Iterations::kCluster > 0, "ThreadMap::Iterations::kCluster must be > 0"); + static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); + + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) + { + } + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + static int const kCount = 
ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { enable(); } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + + /// Shared storage allocation needed by the predicated tile + // iterator for storing rowNorm chunk. + struct SharedStorage { + // + // Type definitions + // + using Shape = MatrixShape; + + /// Shape of the shared memory allocation + using StorageShape = MatrixShape; + + // + // Data members + // + // Methods + // + AlignedBuffer storage; + + CUTLASS_DEVICE + Element* data() { return storage.data(); } + + SharedStorage() {} + + CUTLASS_DEVICE + void initSmem(void* pointer, + const Index& num_rows, + const Index& tb_row_offset, + const LongIndex& stride) + { + Element* shared_elem_arr = data(); + uint8_t* first_tile_byte_pointer_ = + reinterpret_cast(pointer) + LongIndex(tb_row_offset) * LongIndex(stride); + const auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + + for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { + bool guard = (tb_row_offset + row) < num_rows; + cutlass::arch::cp_async(shared_elem_arr + row, gmem_ptr + row, guard); + cutlass::arch::cp_async_wait<0>(); + } + } + }; + + private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + + private: + // + // Methods + // + + protected: + SharedStorage& shared_storage_; + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorNormVecSmem(SharedStorage& shared_storage, + PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord& threadblock_offset, + int const* indices = nullptr) + : params_(params), indices_(indices), shared_storage_(shared_storage) + { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + mask_.predicates[c] = + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if 
(!pointer) { + mask_.clear(); + return; + } + + if (ScatterD && !indices) { mask_.clear(); } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride); + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + } + + if (threadblock_offset.column() == 0) { + shared_storage_.initSmem(pointer, extent_row_, threadblock_offset.row(), params_.stride); + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const + { + AccessType* frag_ptr = reinterpret_cast(&frag); + + Element* shared_elem_arr = shared_storage_.data(); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + int iter_row = ((row_offset + thread_start_row_) % total_rows); + Element val = shared_elem_arr[iter_row]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kElementsPerAccess; ++i) { + (*frag_ptr)[frag_row_idx + i] = val; + } + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { return MatrixCoord(thread_start_row_, thread_start_column_); } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { return thread_start_row_; } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { return thread_start_column_; } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { return extent_row_; } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { return extent_column_; } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorNormVecSmem& operator++() + { + ++state_[0]; + + if (!ScatterD) { byte_pointer_ += params_.advance_row; } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { 
mask_.clear(); } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { mask_.enable(); } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h new file mode 100644 index 0000000000..dc224c5c96 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -0,0 +1,626 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. 
+ +This file contains a customized version of PredicatedTileIterator from CUTLASS 2.9.0 +(https://github.com/NVIDIA/cutlass/blob/v2.9.0/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L75) + +Changes: +- added `Layout_` template param +- PredicatedTileIteratorParams() is customized to not stride by layout.stride(0). +- makes use of `SharedStorage` to store reduced values across warps to gmem in coalesced manner. +- customized the store_with_byte_offset() to perform reduction per row and write final value to +gmem. +- customized the Params() struct to take user inputs from epilogueOp params. + +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cg = cooperative_groups; + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator +/// +template +class PredicatedTileIteratorReducedVec { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = Layout_; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + using EpilogueOpParams = EpilogueOpParams_; + using OutIdxT = typename EpilogueOpParams::CGReduceT::IndexT; + using OutValT = typename EpilogueOpParams::CGReduceT::AccTypeT; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); + static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); + static_assert(ThreadMap::Iterations::kCluster > 0, "ThreadMap::Iterations::kCluster must be > 0"); + static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); + static_assert(!UseCUDAStore, "UseCUDAStore path is not supported"); + + static int const total_rows = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * + ThreadMap::Count::kTile * ThreadMap::Delta::kRow; + /// Fragment object + using Fragment = + Array; + + // Memory access size + using AccessType = AlignedArray; + using AccessTypeValT = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + EpilogueOpParams user_param; + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) + { + } + + CUTLASS_HOST_DEVICE + Params(Layout const& layout, EpilogueOpParams const& user_param_) + : PredicatedTileIteratorParams(int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()), + 
user_param(user_param_) + { + } + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + // static int const kCount = ThreadMap::Iterations::kColumn; + static int const kCount = ThreadMap::Iterations::kColumn * kElementsPerAccess; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { enable(); } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + + /// Shared storage allocation needed by the predicated tile + // iterator for reduction. + struct SharedStorage { + // + // Type definitions + // + using Shape = MatrixShape; + + /// Shape of the shared memory allocation for the reduced values store + using StorageShape = MatrixShape; + + // + // Data members + + // + // Methods + // + AlignedBuffer storage; + + CUTLASS_DEVICE + Element* data() { return storage.data(); } + + SharedStorage() {} + + CUTLASS_DEVICE + void initSmem(EpilogueOpParams const& user_params) + { + Element* shared_elem_arr = data(); + constexpr auto maxVal = std::numeric_limits::max(); + + for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { + user_params.red_op_.init(&shared_elem_arr[row], maxVal); + } + } + }; + + template + struct select_reduce { + /// Performs warp level reduction and stores a reduced output to memory + CUTLASS_DEVICE + select_reduce(OutT value, + ValT prev_red_val, + cg_reduce_op_t reduce_op, + cg_group_t cg_warp_group, + OutT& shmem_ptr) + { + if (cg_warp_group.any(reduce_op.isAmin(value, prev_red_val))) { + OutT reduced_val = cg::reduce(cg_warp_group, value, reduce_op); + if (cg_warp_group.thread_rank() == 0) { shmem_ptr = reduced_val; } + } + } + }; + + template + struct select_reduce> { + using ValT = float; + using Ty = raft::KeyValuePair; + /// Performs warp level reduction of key value pair and stores a reduced output to memory + CUTLASS_DEVICE + select_reduce(Ty val_to_red, + float prev_red_val, + cg_reduce_op_t cg_reduce_op, + cg_group_t cg_warp_group, + Ty& shmem_ptr) + { + ValT val = val_to_red.value; + + if (cg_warp_group.any(cg_reduce_op.isAmin(val, prev_red_val))) { + ValT reduced_val = cg::reduce(cg_warp_group, val, cg_reduce_op); + bool pred = (reduced_val == val); + auto subTile = cg::binary_partition(cg_warp_group, pred); + if (pred) { + if (subTile.thread_rank() == 0) { shmem_ptr = val_to_red; } + } + } + } + }; + + template + struct select_reduce> { + using ValT = double; + using Ty = raft::KeyValuePair; + /// Performs warp level reduction of key value pair and stores a reduced output to memory + CUTLASS_DEVICE + select_reduce(Ty val_to_red, + double prev_red_val, + cg_reduce_op_t cg_reduce_op, + cg_group_t cg_warp_group, + Ty& shmem_ptr) + { + ValT val = val_to_red.value; + + if (cg_warp_group.any(cg_reduce_op.isAmin(val, prev_red_val))) { + ValT reduced_val = cg::reduce(cg_warp_group, val, cg_reduce_op); + bool pred = (reduced_val == val); + auto subTile = cg::binary_partition(cg_warp_group, pred); + if (pred) { + if (subTile.thread_rank() == 0) { shmem_ptr = val_to_red; } + } + } + } + }; + + private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. 
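+  /// (params_.user_param carries the user-supplied reduction functors, the cooperative-groups
+  ///  reduce op, and the global-memory mutexes used by the cross-block reduction in the
+  ///  destructor.)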
+ Params params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + /// Byte-level pointer first tile offset of this threadblock. + uint8_t* first_tile_byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + Index block_start_row_first_tile_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + // mutable int shared_tile_id; + + /// Scatter indices + int const* indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(Params::stride) == 8, "Expected 64b strides"); + + protected: + SharedStorage& shared_storage_; + const bool& do_gmem_reduction_; + + private: + // + // Methods + // + public: + // + // Methods + // + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorReducedVec(SharedStorage& shared_storage, + Params const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + const bool& do_gmem_reduction, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) + : params_(params), + indices_(indices), + shared_storage_(shared_storage), + do_gmem_reduction_(do_gmem_reduction) + { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + TensorCoord block_offset = ThreadMap::initial_offset(0) + threadblock_offset; + block_start_row_first_tile_ = block_offset.row(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn * kElementsPerAccess; ++c) { + int columnPerAccess = (c / kElementsPerAccess); + int columnWithinPerAccess = c % kElementsPerAccess; + mask_.predicates[c] = ((thread_offset.column() + ThreadMap::Delta::kColumn * columnPerAccess + + columnWithinPerAccess) < extent.column()); + } + + if (threadblock_offset.column() == 0) { + EpilogueOpParams const& user_params = params_.user_param; + shared_storage_.initSmem(user_params); + } + + // Null pointer performs no accesses + if (!pointer) { mask_.clear(); } + + if (ScatterD && !indices) { mask_.clear(); } + + // Initialize pointer + first_tile_byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(block_offset.row()) * LongIndex(params_.stride); + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Destructor + CUTLASS_DEVICE + ~PredicatedTileIteratorReducedVec() + { + if (do_gmem_reduction_) { + EpilogueOpParams const& user_params = params_.user_param; + auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); + Element* shared_elem_arr = shared_storage_.data(); + const uint32_t mutex_id = (block_start_row_first_tile_ / total_rows); + bool useGmemMutex = (gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows)); + // If this is not optimal grid size perform mutex based gmem reduce. 
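+      // Protocol: thread 0 acquires mutexes_[mutex_id] via atomicCAS with exponential
+      // __nanosleep backoff, all threads then fold this block's shared-memory row results into
+      // global memory through red_op_, and thread 0 releases the lock with atomicExch after a
+      // __threadfence so the merged values are visible to the next block that takes the lock.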
+ if (useGmemMutex) { + // single lock per block for multiple rows + if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { + // acquire mutex lock. + unsigned int ns = 8; + while (atomicCAS(user_params.mutexes_ + mutex_id, 0, 1) == 1) { + __nanosleep(ns); + if (ns < 256) { ns *= 2; } + } + } + } + + __syncthreads(); + for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { + if (block_start_row_first_tile_ + row < extent_row_) { + user_params.red_op_( + block_start_row_first_tile_ + row, &gmem_ptr[row], shared_elem_arr[row]); + } + } + + if (useGmemMutex) { + __threadfence(); + __syncthreads(); + if (threadIdx.x == 0 && block_start_row_first_tile_ < extent_row_) { + // release mutex lock. + atomicExch(user_params.mutexes_ + mutex_id, 0); + } + } + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Performs reduction and Stores a reduced output to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment& frag, int64_t byte_offset) const + { + AccessTypeValT* frag_ptr = reinterpret_cast(&frag); + + cg::thread_block cta = cg::this_thread_block(); + // tile_width 16 is required if kElementPerAccess > 1 + constexpr int tile_width = (32 / ThreadMap::Delta::kColumn) ? 32 : 16; + cg::thread_block_tile tile32 = cg::tiled_partition(cta); + EpilogueOpParams const& user_params = params_.user_param; + + using cg_reduce_t = decltype(user_params.cg_reduce_op); + using tile32_t = decltype(tile32); + + Element* shared_elem_arr = shared_storage_.data(); + constexpr auto maxVal = std::numeric_limits::max(); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + const OutIdxT row_id = row_offset + thread_start_row_; + bool row_guard = (row_id < extent_row_); + + const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn * kElementsPerAccess; + Element red_val; + user_params.red_op_.init(&red_val, maxVal); + + if (row_guard) { + const int iter_row = (row_id % total_rows); + const auto prev_red_val = user_params.red_op_.get_value(shared_elem_arr[iter_row]); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn * kElementsPerAccess; + ++column) { + int columnPerAccess = column / kElementsPerAccess; + int columnWithPerAccess = column % kElementsPerAccess; + bool guard = mask_.predicates[column]; + if (guard) { + const OutIdxT key_id = thread_start_column_ + + ThreadMap::Delta::kColumn * columnPerAccess + + columnWithPerAccess; + const int frag_col_idx = frag_idx + column; + + Element this_val; + user_params.red_op_.init(&this_val, (*frag_ptr)[frag_col_idx]); + user_params.red_op_.init_key(this_val, key_id); + user_params.red_op_(row_id, &red_val, this_val); + } + } + // select_reduce doesn't need to use `red_op_` as at the warp level we use cg_reduce_op, + // this satisfies the requirement of mst/single linkage of checking colors buffer. 
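+          // select_reduce performs a cooperative-groups cg::reduce() of the candidate distances
+          // across the warp tile and lets only the winning lane write the reduced (key, value)
+          // into the shared-memory slot for this row; red_op_ has already been applied
+          // per-thread above.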
+ select_reduce red_obj( + red_val, prev_red_val, user_params.cg_reduce_op, tile32, shared_elem_arr[iter_row]); + } + } + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment& frag) const { store_with_byte_offset(frag, 0); } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { return MatrixCoord(thread_start_row_, thread_start_column_); } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { return thread_start_row_; } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { return thread_start_column_; } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { return extent_row_; } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { return extent_column_; } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorReducedVec& operator++() + { + ++state_[0]; + + if (!ScatterD) { byte_pointer_ += params_.advance_row; } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { mask_.clear(); } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { mask_.enable(); } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index be6fed9f10..2ff8fa7f1c 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -21,8 +21,10 @@ #include // raft::KeyValuePair #include // raft::identity_op #include // ops::l2_exp_distance_op +#include #include // PairwiseDistances #include // Policy +#include // raft::util::arch::SM_* #include // raft::ceildiv, raft::shfl namespace raft { @@ -41,7 +43,7 @@ struct KVPMinReduceImpl { template struct MinAndDistanceReduceOpImpl { typedef typename raft::KeyValuePair KVP; - DI void operator()(LabelT rid, KVP* out, const KVP& other) + DI void operator()(LabelT rid, KVP* out, const KVP& other) const { if (other.value < out->value) { out->key = other.key; @@ -49,17 +51,28 @@ struct MinAndDistanceReduceOpImpl { } } - DI void operator()(LabelT rid, DataT* out, const KVP& other) + DI void operator()(LabelT rid, DataT* out, const KVP& other) const { if (other.value < *out) { *out = other.value; } } - DI void init(DataT* out, DataT 
maxVal) { *out = maxVal; } - DI void init(KVP* out, DataT maxVal) + DI void operator()(LabelT rid, DataT* out, const DataT& other) const { - out->key = 0; - out->value = maxVal; + if (other < *out) { *out = other; } } + + DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } + DI void init(KVP* out, DataT maxVal) const { out->value = maxVal; } + + DI void init_key(DataT& out, LabelT idx) const { return; } + DI void init_key(KVP& out, LabelT idx) const { out.key = idx; } + + DI DataT get_value(KVP& out) const + { + return out.value; + ; + } + DI DataT get_value(DataT& out) const { return out; } }; template @@ -141,6 +154,8 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, OpT distance_op, FinalLambda fin_op) { +// compile only if below non-ampere arch. +#if __CUDA_ARCH__ < 800 extern __shared__ char smem[]; typedef KeyValuePair KVPair; @@ -236,8 +251,29 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, fin_op, rowEpilog_lambda); obj.run(); +#endif } +// cg::reduce functor for FusedDistanceNN used in its cutlass version +// to output the min distance value & key(loc id). +// This is used in fused_distance_nn/predicated_tile_iterator_reduced_vec.h +// store_with_byte_offset() passed to cg::reduce() & select_reduce. +template +struct kvp_cg_min_reduce_op { + typedef typename raft::KeyValuePair KVP; + + __host__ __device__ kvp_cg_min_reduce_op() noexcept {}; + + using AccTypeT = AccType; + using IndexT = Index; + // functor signature. + __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } + + __host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } + + __host__ __device__ bool isAmin(AccType a, AccType b) const { return a < b ? true : false; } +}; + template distance_op{sqrt}; raft::identity_op fin_op{}; @@ -290,11 +325,58 @@ void fusedL2NNImpl(OutT* min, decltype(distance_op), decltype(fin_op)>; - dim3 grid = launchConfigGenerator
<P>
(m, n, shmemSize, kernel); - - kernel<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); - RAFT_CUDA_TRY(cudaGetLastError()); + // Get pointer to fp32 SIMT kernel to determine the runtime architecture of the + // current system. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. See: + // https://github.com/NVIDIA/cub/issues/545 + void* kernel_ptr = reinterpret_cast(kernel); + auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr); + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. + using L2Op = raft::distance::detail::ops::l2_exp_cutlass_op; + using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; + kvp_cg_min_reduce_op_ cg_reduce_op; + L2Op L2_dist_op(sqrt); + + IdxT lda, ldb, ldd; + lda = k, ldb = k, ldd = n; + + cutlassFusedDistanceNN(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + min, + workspace, + cg_reduce_op, + L2_dist_op, + redOp, + pairRedOp, + stream); + } else { + // If device less than SM_80, use fp32 SIMT kernel. + constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + dim3 grid = launchConfigGenerator
<P>
(m, n, shmemSize, kernel); + + kernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); + RAFT_CUDA_TRY(cudaGetLastError()); + } } } // namespace detail diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh index efcd5d9389..ccb3bd46bf 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -18,6 +18,7 @@ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wtautological-compare" // We define CUTLASS_NAMESPACE in case // RAFT cmake is not used @@ -38,20 +39,11 @@ #include #include +#include #include "./pairwise_distance_epilogue_elementwise.h" #include "./pairwise_distance_gemm.h" -#define CUTLASS_CHECK(status) \ - { \ - cutlass::Status error = status; \ - if (error != cutlass::Status::kSuccess) { \ - std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ - << std::endl; \ - exit(EXIT_FAILURE); \ - } \ - } - namespace raft { namespace distance { namespace detail { @@ -164,14 +156,13 @@ std::enable_if_t::value> cutlassDistanceKernel(const Da // Instantiate CUTLASS kernel depending on templates cutlassDist cutlassDist_op; // Check the problem size is supported or not - cutlass::Status status = cutlassDist_op.can_implement(arguments); - CUTLASS_CHECK(status); + RAFT_CUTLASS_TRY(cutlassDist_op.can_implement(arguments)); + // Initialize CUTLASS kernel with arguments and workspace pointer - status = cutlassDist_op.initialize(arguments, workspace.data(), stream); - CUTLASS_CHECK(status); + RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream)); + // Launch initialized CUTLASS kernel - status = cutlassDist_op(); - CUTLASS_CHECK(status); + RAFT_CUTLASS_TRY(cutlassDist_op()); } }; // namespace detail diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h index ebe6d0c80a..cd748b9e6b 100644 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h @@ -284,11 +284,15 @@ class PredicatedTileIteratorNormVec { CUTLASS_PRAGMA_UNROLL for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { bool guard = row_guard && mask_.predicates[column]; - - cutlass::arch::global_load( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void*)&memory_pointer[0], - guard); + if (column == 0) { + cutlass::arch::global_load( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer[0], + guard); + } else { + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn]; + } } if (row + 1 < ThreadMap::Iterations::kRow) { diff --git a/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh b/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh index adcb566cea..f089cbea83 100644 --- a/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh +++ b/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh @@ -61,10 +61,13 @@ struct FixConnectivitiesRedOp { value_idx* colors; value_idx m; + // default constructor for cutlass + DI FixConnectivitiesRedOp() : colors(0), m(0) {} + FixConnectivitiesRedOp(value_idx* colors_, 
value_idx m_) : colors(colors_), m(m_){}; typedef typename raft::KeyValuePair KVP; - DI void operator()(value_idx rit, KVP* out, const KVP& other) + DI void operator()(value_idx rit, KVP* out, const KVP& other) const { if (rit < m && other.value < out->value && colors[rit] != colors[other.key]) { out->key = other.key; @@ -72,9 +75,7 @@ struct FixConnectivitiesRedOp { } } - DI KVP - - operator()(value_idx rit, const KVP& a, const KVP& b) + DI KVP operator()(value_idx rit, const KVP& a, const KVP& b) const { if (rit < m && a.value < b.value && colors[rit] != colors[a.key]) { return a; @@ -82,12 +83,19 @@ struct FixConnectivitiesRedOp { return b; } - DI void init(value_t* out, value_t maxVal) { *out = maxVal; } - DI void init(KVP* out, value_t maxVal) + DI void init(value_t* out, value_t maxVal) const { *out = maxVal; } + DI void init(KVP* out, value_t maxVal) const { out->key = -1; out->value = maxVal; } + + DI void init_key(value_t& out, value_idx idx) const { return; } + DI void init_key(KVP& out, value_idx idx) const { out.key = idx; } + + DI value_t get_value(KVP& out) const { return out.value; } + + DI value_t get_value(value_t& out) const { return out; } }; /** diff --git a/cpp/include/raft/util/cutlass_utils.cuh b/cpp/include/raft/util/cutlass_utils.cuh new file mode 100644 index 0000000000..da402c9427 --- /dev/null +++ b/cpp/include/raft/util/cutlass_utils.cuh @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft { + +/** + * @brief Exception thrown when a CUTLASS error is encountered. + */ +struct cutlass_error : public raft::exception { + explicit cutlass_error(char const* const message) : raft::exception(message) {} + explicit cutlass_error(std::string const& message) : raft::exception(message) {} +}; + +} // namespace raft + +/** + * @brief Error checking macro for CUTLASS functions. + * + * Invokes a CUTLASS function call, if the call does not return cutlass::Status::kSuccess, + * throws an exception detailing the CUTLASS error that occurred. 
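+ * Example (as used in pairwise_distance_cutlass_base.cuh in this patch):
+ *   RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream));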
+ * + */ +#define RAFT_CUTLASS_TRY(call) \ + do { \ + cutlass::Status const status = call; \ + if (status != cutlass::Status::kSuccess) { \ + std::string msg{}; \ + SET_ERROR_MSG(msg, \ + "CUTLASS error encountered at: ", \ + "call='%s', Reason=%s", \ + #call, \ + cutlassGetStatusString(status)); \ + throw raft::cutlass_error(msg); \ + } \ + } while (0) diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index e3f3bf3324..e807256f67 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -276,6 +276,8 @@ const std::vector> inputsf = { {0.001f, 128, 128, 65, 1234ULL}, {0.001f, 64, 128, 129, 1234ULL}, {0.006f, 1805, 134, 2, 1234ULL}, + {0.006f, 8192, 1024, 64, 1234ULL}, + {0.006f, 8192, 1025, 64, 1234ULL}, // Repeat with smaller values of k {0.006f, 32, 32, 1, 1234ULL}, @@ -305,6 +307,8 @@ const std::vector> inputsf = { {0.001f, 128, 128, 23, 1234ULL}, {0.00001, 64, 128, 24, 1234ULL}, {0.001f, 1805, 134, 25, 1234ULL}, + {0.006f, 8192, 1024, 25, 1234ULL}, + {0.006f, 8192, 1024, 66, 1234ULL}, }; typedef FusedL2NNTest FusedL2NNTestF_Sq; TEST_P(FusedL2NNTestF_Sq, Result) @@ -339,7 +343,7 @@ const std::vector> inputsd = { {0.00001, 128, 32, 33, 1234ULL}, {0.00001, 128, 64, 33, 1234ULL}, {0.00001, 128, 128, 65, 1234ULL}, {0.00001, 64, 128, 129, 1234ULL}, - {0.00001, 1805, 134, 2, 1234ULL}, + {0.00001, 1805, 134, 2, 1234ULL}, {0.00001, 8192, 1024, 25, 1234ULL}, }; typedef FusedL2NNTest FusedL2NNTestD_Sq; TEST_P(FusedL2NNTestD_Sq, Result)
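Note for downstream users of fusedL2NNImpl with a custom reduction op: the CUTLASS epilogue calls
the redOp through a slightly wider, const-qualified interface than the SIMT path did (a default
constructor plus init, init_key and get_value overloads). Below is a minimal sketch of the required
shape, mirroring the two functors updated in this patch; the struct name and the kvp.hpp include
path are illustrative, not part of the patch.

#include <raft/core/kvp.hpp>  // raft::KeyValuePair

template <typename LabelT, typename DataT>
struct MyMinRedOp {
  using KVP = raft::KeyValuePair<LabelT, DataT>;

  __host__ __device__ MyMinRedOp() {}  // default-constructible for the CUTLASS kernel path

  __device__ void operator()(LabelT rid, KVP* out, const KVP& other) const
  {
    if (other.value < out->value) { *out = other; }
  }
  __device__ void operator()(LabelT rid, DataT* out, const KVP& other) const
  {
    if (other.value < *out) { *out = other.value; }
  }

  __device__ void init(DataT* out, DataT maxVal) const { *out = maxVal; }
  __device__ void init(KVP* out, DataT maxVal) const { out->value = maxVal; }

  // init_key attaches the candidate's column index; get_value exposes the distance so the
  // reduced-vector tile iterator can compare it against the previously reduced value.
  __device__ void init_key(DataT& out, LabelT idx) const {}
  __device__ void init_key(KVP& out, LabelT idx) const { out.key = idx; }
  __device__ DataT get_value(KVP& out) const { return out.value; }
  __device__ DataT get_value(DataT& out) const { return out; }
};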