From f47c2f1c9a6995f71d2b8954993f1fbc255f5e61 Mon Sep 17 00:00:00 2001 From: Alex Nichol Date: Wed, 20 Mar 2024 21:14:38 -0400 Subject: [PATCH] macro cleanup --- accelerated_scan/warp.cuh | 63 ++++++++++++++------------------------- 1 file changed, 22 insertions(+), 41 deletions(-) diff --git a/accelerated_scan/warp.cuh b/accelerated_scan/warp.cuh index 6a70f35..811df27 100644 --- a/accelerated_scan/warp.cuh +++ b/accelerated_scan/warp.cuh @@ -222,54 +222,35 @@ __global__ void scan( } } +#define DISPATCH_SCAN_INNER(TupleT, backward, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse) \ + scan<<>>( \ + reinterpret_cast(gates.data_ptr()), \ + reinterpret_cast(tokens.data_ptr()), \ + reinterpret_cast(out.data_ptr()), \ + reinterpret_cast(output), \ + reinterpret_cast(gateGradOut), \ + batch_stride, dim_stride, reverse \ + ); + #define DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse) \ using AlignedT = AlignedTuple; \ using UnalignedT = UnalignedTuple; \ - if (!output) { \ - if (kNStepsPerThread == 4 && \ - ((long)gates.data_ptr()) % 16 == 0 && \ - ((long)tokens.data_ptr()) % 16 == 0 && \ - ((long)out.data_ptr()) % 16 == 0) { \ - scan<<>>( \ - reinterpret_cast(gates.data_ptr()), \ - reinterpret_cast(tokens.data_ptr()), \ - reinterpret_cast(out.data_ptr()), \ - nullptr, nullptr, \ - batch_stride, dim_stride, reverse \ - ); \ + if (kNStepsPerThread == 4 && \ + ((long)gates.data_ptr()) % 16 == 0 && \ + ((long)tokens.data_ptr()) % 16 == 0 && \ + ((long)out.data_ptr()) % 16 == 0 && \ + ((long)output) % 16 == 0 && \ + ((long)gateGradOut) % 16 == 0) { \ + if (output) { \ + DISPATCH_SCAN_INNER(AlignedT, true, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \ } else { \ - scan<<>>( \ - reinterpret_cast(gates.data_ptr()), \ - reinterpret_cast(tokens.data_ptr()), \ - reinterpret_cast(out.data_ptr()), \ - nullptr, nullptr, \ - batch_stride, dim_stride, reverse \ - ); \ + DISPATCH_SCAN_INNER(AlignedT, false, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \ } \ } else { \ - if (kNStepsPerThread == 4 && \ - ((long)gates.data_ptr()) % 16 == 0 && \ - ((long)tokens.data_ptr()) % 16 == 0 && \ - ((long)out.data_ptr()) % 16 == 0 && \ - ((long)output) % 16 == 0 && \ - ((long)gateGradOut) % 16 == 0) { \ - scan<<>>( \ - reinterpret_cast(gates.data_ptr()), \ - reinterpret_cast(tokens.data_ptr()), \ - reinterpret_cast(out.data_ptr()), \ - reinterpret_cast(output), \ - reinterpret_cast(gateGradOut), \ - batch_stride, dim_stride, reverse \ - ); \ + if (output) { \ + DISPATCH_SCAN_INNER(UnalignedT, true, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \ } else { \ - scan<<>>( \ - reinterpret_cast(gates.data_ptr()), \ - reinterpret_cast(tokens.data_ptr()), \ - reinterpret_cast(out.data_ptr()), \ - reinterpret_cast(output), \ - reinterpret_cast(gateGradOut), \ - batch_stride, dim_stride, reverse \ - ); \ + DISPATCH_SCAN_INNER(UnalignedT, false, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \ } \ }