Skip to content

Commit

Permalink
macro cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
unixpickle committed Mar 21, 2024
1 parent 4e57346 commit f47c2f1
Showing 1 changed file with 22 additions and 41 deletions.
63 changes: 22 additions & 41 deletions accelerated_scan/warp.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -222,54 +222,35 @@ __global__ void scan(
}
}

#define DISPATCH_SCAN_INNER(TupleT, backward, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse) \
scan<TupleT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, backward><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
reinterpret_cast<const TupleT *>(gates.data_ptr<torch_weight_t>()), \
reinterpret_cast<const TupleT *>(tokens.data_ptr<torch_weight_t>()), \
reinterpret_cast<TupleT *>(out.data_ptr<torch_weight_t>()), \
reinterpret_cast<const TupleT *>(output), \
reinterpret_cast<TupleT *>(gateGradOut), \
batch_stride, dim_stride, reverse \
);

#define DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse) \
using AlignedT = AlignedTuple<weight_t, kNStepsPerThread>; \
using UnalignedT = UnalignedTuple<weight_t, kNStepsPerThread>; \
if (!output) { \
if (kNStepsPerThread == 4 && \
((long)gates.data_ptr()) % 16 == 0 && \
((long)tokens.data_ptr()) % 16 == 0 && \
((long)out.data_ptr()) % 16 == 0) { \
scan<AlignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, false><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
reinterpret_cast<const AlignedT *>(gates.data_ptr<torch_weight_t>()), \
reinterpret_cast<const AlignedT *>(tokens.data_ptr<torch_weight_t>()), \
reinterpret_cast<AlignedT *>(out.data_ptr<torch_weight_t>()), \
nullptr, nullptr, \
batch_stride, dim_stride, reverse \
); \
if (kNStepsPerThread == 4 && \
((long)gates.data_ptr()) % 16 == 0 && \
((long)tokens.data_ptr()) % 16 == 0 && \
((long)out.data_ptr()) % 16 == 0 && \
((long)output) % 16 == 0 && \
((long)gateGradOut) % 16 == 0) { \
if (output) { \
DISPATCH_SCAN_INNER(AlignedT, true, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
} else { \
scan<UnalignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, false><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
reinterpret_cast<const UnalignedT*>(gates.data_ptr<torch_weight_t>()), \
reinterpret_cast<const UnalignedT*>(tokens.data_ptr<torch_weight_t>()), \
reinterpret_cast<UnalignedT *>(out.data_ptr<torch_weight_t>()), \
nullptr, nullptr, \
batch_stride, dim_stride, reverse \
); \
DISPATCH_SCAN_INNER(AlignedT, false, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
} \
} else { \
if (kNStepsPerThread == 4 && \
((long)gates.data_ptr()) % 16 == 0 && \
((long)tokens.data_ptr()) % 16 == 0 && \
((long)out.data_ptr()) % 16 == 0 && \
((long)output) % 16 == 0 && \
((long)gateGradOut) % 16 == 0) { \
scan<AlignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, true><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
reinterpret_cast<const AlignedT *>(gates.data_ptr<torch_weight_t>()), \
reinterpret_cast<const AlignedT *>(tokens.data_ptr<torch_weight_t>()), \
reinterpret_cast<AlignedT *>(out.data_ptr<torch_weight_t>()), \
reinterpret_cast<const AlignedT *>(output), \
reinterpret_cast<AlignedT *>(gateGradOut), \
batch_stride, dim_stride, reverse \
); \
if (output) { \
DISPATCH_SCAN_INNER(UnalignedT, true, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
} else { \
scan<UnalignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, true><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
reinterpret_cast<const UnalignedT*>(gates.data_ptr<torch_weight_t>()), \
reinterpret_cast<const UnalignedT*>(tokens.data_ptr<torch_weight_t>()), \
reinterpret_cast<UnalignedT *>(out.data_ptr<torch_weight_t>()), \
reinterpret_cast<const UnalignedT *>(output), \
reinterpret_cast<UnalignedT *>(gateGradOut), \
batch_stride, dim_stride, reverse \
); \
DISPATCH_SCAN_INNER(UnalignedT, false, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
} \
}

Expand Down

0 comments on commit f47c2f1

Please sign in to comment.