Skip to content

Commit

Permalink
Parallelize kernel compilation in FAISS (facebookresearch#2922)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#2922

This parallelizes kernel compilation by taking a template function from much deeper in the stack than was previously the case and generating 128 compilation units rather than the original 8.

Reviewed By: mdouze

Differential Revision: D46674315

fbshipit-source-id: 830eeaf43dee2c081f735be47c809b28aa3a05f6
  • Loading branch information
r-barnes authored and Thejas-bhat committed Sep 26, 2023
1 parent 2e87779 commit c98c19e
Show file tree
Hide file tree
Showing 13 changed files with 179 additions and 279 deletions.
76 changes: 68 additions & 8 deletions faiss/gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,6 @@ set(FAISS_GPU_SRC
impl/PQScanMultiPassPrecomputed.cu
impl/RemapIndices.cpp
impl/VectorResidual.cu
impl/scan/IVFInterleaved1.cu
impl/scan/IVFInterleaved32.cu
impl/scan/IVFInterleaved64.cu
impl/scan/IVFInterleaved128.cu
impl/scan/IVFInterleaved256.cu
impl/scan/IVFInterleaved512.cu
impl/scan/IVFInterleaved1024.cu
impl/scan/IVFInterleaved2048.cu
impl/IcmEncoder.cu
utils/BlockSelectFloat.cu
utils/DeviceUtils.cu
Expand Down Expand Up @@ -176,6 +168,74 @@ set(FAISS_GPU_HEADERS
utils/warpselect/WarpSelectImpl.cuh
)

function(generate_ivf_interleaved_code)
set(SUB_CODEC_TYPE
"faiss::gpu::Codec<0, 1>"
"faiss::gpu::Codec<1, 1>"
"faiss::gpu::Codec<2, 1>"
"faiss::gpu::Codec<3, 1>"
"faiss::gpu::Codec<4, 1>"
"faiss::gpu::Codec<5, 1>"
"faiss::gpu::Codec<6, 1>"
"faiss::gpu::CodecFloat"
)

set(SUB_METRIC_TYPE
"faiss::gpu::IPDistance"
"faiss::gpu::L2Distance"
)

# Used for SUB_THREADS, SUB_NUM_WARP_Q, SUB_NUM_THREAD_Q
set(THREADS_AND_WARPS
"128|1024|8"
"128|1|1"
"128|128|3"
"128|256|4"
"128|32|2"
"128|512|8"
"128|64|3"
"64|2048|8"
)

# Traverse through the Cartesian product of X and Y
foreach(sub_codec ${SUB_CODEC_TYPE})
foreach(metric_type ${SUB_METRIC_TYPE})
foreach(threads_and_warps_str ${THREADS_AND_WARPS})
string(REPLACE "|" ";" threads_and_warps ${threads_and_warps_str})
list(GET threads_and_warps 0 sub_threads)
list(GET threads_and_warps 1 sub_num_warp_q)
list(GET threads_and_warps 2 sub_num_thread_q)

# Define the output file name
set(filename "template_${sub_codec}_${metric_type}_${sub_threads}_${sub_num_warp_q}_${sub_num_thread_q}")
# Remove illegal characters from filename
string(REGEX REPLACE "[^A-Za-z0-9_]" "" filename ${filename})
set(output_file "${CMAKE_CURRENT_BINARY_DIR}/${filename}.cu")

# Read the template file
file(READ "${CMAKE_CURRENT_SOURCE_DIR}/impl/scan/IVFInterleavedScanKernelTemplate.cu" template_content)

# Replace the placeholders
string(REPLACE "SUB_CODEC_TYPE" "${sub_codec}" template_content "${template_content}")
string(REPLACE "SUB_METRIC_TYPE" "${metric_type}" template_content "${template_content}")
string(REPLACE "SUB_THREADS" "${sub_threads}" template_content "${template_content}")
string(REPLACE "SUB_NUM_WARP_Q" "${sub_num_warp_q}" template_content "${template_content}")
string(REPLACE "SUB_NUM_THREAD_Q" "${sub_num_thread_q}" template_content "${template_content}")

# Write the modified content to the output file
file(WRITE "${output_file}" "${template_content}")

# Add the file to the sources
list(APPEND FAISS_GPU_SRC "${output_file}")
endforeach()
endforeach()
endforeach()
# Propagate modified variable to the parent scope
set(FAISS_GPU_SRC "${FAISS_GPU_SRC}" PARENT_SCOPE)
endfunction()

generate_ivf_interleaved_code()

if(FAISS_ENABLE_RAFT)
list(APPEND FAISS_GPU_HEADERS
impl/RaftFlatIndex.cuh)
Expand Down
18 changes: 8 additions & 10 deletions faiss/gpu/impl/IVFInterleaved.cu
Original file line number Diff line number Diff line change
Expand Up @@ -210,25 +210,23 @@ void runIVFInterleavedScan(
};

if (k == 1) {
ivf_interleaved_call(ivfInterleavedScanImpl<IVFINTERLEAVED_1_PARAMS>);
ivf_interleaved_call(ivfInterleavedScanImpl<128, 1, 1>);
} else if (k <= 32) {
ivf_interleaved_call(ivfInterleavedScanImpl<IVFINTERLEAVED_32_PARAMS>);
ivf_interleaved_call(ivfInterleavedScanImpl<128, 32, 2>);
} else if (k <= 64) {
ivf_interleaved_call(ivfInterleavedScanImpl<IVFINTERLEAVED_64_PARAMS>);
ivf_interleaved_call(ivfInterleavedScanImpl<128, 64, 3>);
} else if (k <= 128) {
ivf_interleaved_call(ivfInterleavedScanImpl<IVFINTERLEAVED_128_PARAMS>);
ivf_interleaved_call(ivfInterleavedScanImpl<128, 128, 3>);
} else if (k <= 256) {
ivf_interleaved_call(ivfInterleavedScanImpl<IVFINTERLEAVED_256_PARAMS>);
ivf_interleaved_call(ivfInterleavedScanImpl<128, 256, 4>);
} else if (k <= 512) {
ivf_interleaved_call(ivfInterleavedScanImpl<IVFINTERLEAVED_512_PARAMS>);
ivf_interleaved_call(ivfInterleavedScanImpl<128, 512, 8>);
} else if (k <= 1024) {
ivf_interleaved_call(
ivfInterleavedScanImpl<IVFINTERLEAVED_1024_PARAMS>);
ivf_interleaved_call(ivfInterleavedScanImpl<128, 1024, 8>);
}
#if GPU_MAX_SELECTION_K >= 2048
else if (k <= 2048) {
ivf_interleaved_call(
ivfInterleavedScanImpl<IVFINTERLEAVED_2048_PARAMS>);
ivf_interleaved_call(ivfInterleavedScanImpl<64, 2048, 8>);
}
#endif
}
Expand Down
18 changes: 9 additions & 9 deletions faiss/gpu/impl/IVFInterleaved.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ template <
typename Metric,
int ThreadsPerBlock,
int NumWarpQ,
int NumThreadQ,
bool Residual>
int NumThreadQ>
__global__ void ivfInterleavedScan(
Tensor<float, 2, true> queries,
Tensor<float, 3, true> residualBase,
Expand All @@ -48,7 +47,8 @@ __global__ void ivfInterleavedScan(
int k,
// [query][probe][k]
Tensor<float, 3, true> distanceOut,
Tensor<idx_t, 3, true> indicesOut) {
Tensor<idx_t, 3, true> indicesOut,
const bool Residual) {
extern __shared__ float smem[];

constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
Expand Down Expand Up @@ -124,7 +124,7 @@ __global__ void ivfInterleavedScan(
for (int dBase = 0; dBase < dimBlocks; dBase += kWarpSize) {
const int loadDim = dBase + laneId;
const float queryReg = query[loadDim];
[[maybe_unused]] const float residualReg =
const float residualReg =
Residual ? residualBaseSlice[loadDim] : 0;

constexpr int kUnroll = 4;
Expand Down Expand Up @@ -152,7 +152,7 @@ __global__ void ivfInterleavedScan(
decV[j] = codec.decodeNew(dBase + d, encV[j]);
}

if constexpr (Residual) {
if (Residual) {
#pragma unroll
for (int j = 0; j < kUnroll; ++j) {
int d = i * kUnroll + j;
Expand All @@ -174,9 +174,9 @@ __global__ void ivfInterleavedScan(
const bool loadDimInBounds = loadDim < dim;

const float queryReg = loadDimInBounds ? query[loadDim] : 0;
[[maybe_unused]] const float residualReg =
Residual && loadDimInBounds ? residualBaseSlice[loadDim]
: 0;
const float residualReg = Residual && loadDimInBounds
? residualBaseSlice[loadDim]
: 0;

for (int d = 0; d < dim - dimBlocks;
++d, data += wordsPerVectorBlockDim) {
Expand All @@ -187,7 +187,7 @@ __global__ void ivfInterleavedScan(
enc = WarpPackedBits<EncodeT, Codec::kEncodeBits>::postRead(
laneId, enc);
float dec = codec.decodeNew(dimBlocks + d, enc);
if constexpr (Residual) {
if (Residual) {
dec += SHFL_SYNC(residualReg, d, kWarpSize);
}

Expand Down
16 changes: 0 additions & 16 deletions faiss/gpu/impl/scan/IVFInterleaved1.cu

This file was deleted.

16 changes: 0 additions & 16 deletions faiss/gpu/impl/scan/IVFInterleaved1024.cu

This file was deleted.

16 changes: 0 additions & 16 deletions faiss/gpu/impl/scan/IVFInterleaved128.cu

This file was deleted.

18 changes: 0 additions & 18 deletions faiss/gpu/impl/scan/IVFInterleaved2048.cu

This file was deleted.

16 changes: 0 additions & 16 deletions faiss/gpu/impl/scan/IVFInterleaved256.cu

This file was deleted.

16 changes: 0 additions & 16 deletions faiss/gpu/impl/scan/IVFInterleaved32.cu

This file was deleted.

16 changes: 0 additions & 16 deletions faiss/gpu/impl/scan/IVFInterleaved512.cu

This file was deleted.

16 changes: 0 additions & 16 deletions faiss/gpu/impl/scan/IVFInterleaved64.cu

This file was deleted.

Loading

0 comments on commit c98c19e

Please sign in to comment.