Skip to content

Commit

Permalink
relax swizzle.
Browse files Browse the repository at this point in the history
  • Loading branch information
lcy-seso committed Jan 23, 2025
1 parent 29ca160 commit 4292eb4
Show file tree
Hide file tree
Showing 8 changed files with 315 additions and 322 deletions.
173 changes: 68 additions & 105 deletions include/cell/copy/global_to_shared.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@ namespace tl = tile_layout;
/**
* @brief Load a warp tile from global memory to shared memory.
*
* This function loads a warp tile whose shape is specified by `BaseShape`
* from global memory to shared memory.
* This function loads a data tile from global to shared memory.
*
* @tparam Global_ The type of the global memory pointer.
* @tparam Shared_ The type of the shared memory pointer.
* @tparam BaseShape_ The shape of the warp tile.
* @tparam kRowExec_ The number of rows to execute.
* @tparam kColExec_ The number of columns to execute.
* @tparam kType The type of the elements to be loaded.
* @tparam Global The type of the global memory tile.
* @tparam Shared The type of the shared memory tile.
* @tparam BaseShape The shape of the base tile.
* @tparam kRowExec The number of rows to execute.
* @tparam kColExec The number of columns to execute.
* @tparam kType The type of Global and Shared memory layout.
*/
template <typename Global, typename Shared, typename BaseShape,
const int kRowExec, const int kColExec,
Expand Down Expand Up @@ -54,34 +53,26 @@ struct GlobalToSharedLoaderImpl<Global_, Shared_, BaseShape_, kRowExec_,
static constexpr int kColExec = kColExec_;

DEVICE void operator()(const DType* src, DType* dst) {
// TODO(KuangjuX): When the `WarpRow` is greater than 1, a swizzle block
// might be split by two warps, and a solution is needed to address this
// situation.
int row = lane_row_id();
int col = lane_col_id() * kNumPerAccess;

/// the pointer offset inside a warp tile.
int src_lane_offset = src_in_base_tile_(row, col); // global
int dst_lane_offset = dst_in_base_tile_(row, col); // shared

int src_offset = 0, dst_offset = 0;
uint32_t dst_ptr;
#pragma unroll
for (int i = 0; i < kRowExec; ++i) {
#pragma unroll
for (int j = 0; j < kColExec; ++j) {
int tile_i =
(i * BaseShape::kRows + row) / SwizzledBaseShape::kRows;
int tile_j =
(j * BaseShape::kCols + col) / SwizzledBaseShape::kCols;
int tile_row =
(i * BaseShape::kRows + row) % SwizzledBaseShape::kRows;
int tile_col =
(j * BaseShape::kCols + col) % SwizzledBaseShape::kCols;

/// the pointer offset inside a warp tile.
int src_lane_offset = src_tile_(row, col);
int dst_tile_offset = dst_tile_(tile_row, tile_col);

src_offset = src_base_tiles_(i, j) + src_lane_offset;
dst_offset = dst_base_tiles_(tile_i, tile_j) + dst_tile_offset;
dst_offset = dst_base_tiles_(i, j) + dst_lane_offset;

dst_ptr = static_cast<uint32_t>(
__cvta_generic_to_shared(dst + dst_offset));

copy(src + src_offset, dst + dst_offset);
ld_global_st_shared<kAccessInBytes>(dst_ptr, src + src_offset);
}
}
}
Expand All @@ -90,64 +81,49 @@ struct GlobalToSharedLoaderImpl<Global_, Shared_, BaseShape_, kRowExec_,
static constexpr int kNumPerAccess =
traits::AccessBase<DType>::kNumPerAccess;

using SwizzledBaseShape = traits::SwizzleBaseTileShape<DType>;
static constexpr int kSwizzledRows = SwizzledBaseShape::kRows;
static constexpr int kSwizzledCols = SwizzledBaseShape::kCols;
static constexpr int B = SwizzledBaseShape::B;
static constexpr int M = SwizzledBaseShape::M;
static constexpr int S = SwizzledBaseShape::S;
static constexpr int kAccessInBytes =
traits::AccessBase<DType>::kAccessInBytes;

static constexpr int kSwizzledRowExec =
kRowExec / (kSwizzledRows / BaseShape::kRows);
static constexpr int kSwizzledColExec =
kColExec / (kSwizzledCols / BaseShape::kCols);

using SrcBaseTilesLayout =
using SrcBaseTilesLayout = // global
tl::MatrixLayout<kRowExec, kColExec,
BaseShape::kRows * Global::kRowStride,
BaseShape::kCols>;
SrcBaseTilesLayout src_base_tiles_;

using DstSwizzledLayout =
tl::MatrixLayout<kSwizzledRowExec, kSwizzledColExec,
kSwizzledRows * Shared::kRowStride, kSwizzledCols>;
DstSwizzledLayout dst_base_tiles_;
using DstBaseTilesLayout = // shared
tl::MatrixLayout<kRowExec, kColExec,
BaseShape::kRows * Shared::kRowStride,
BaseShape::kNumel>;
DstBaseTilesLayout dst_base_tiles_;

// Given a thread index, the GlobalLayout and SharedLayout below return the
// data offset from which the thread should load from the global memory tile
// and where to store it in the shared memory tile, respectively.
using GlobalLayout = tl::MatrixLayout<BaseShape::kRows, BaseShape::kCols,
Global::kRowStride, 1>;
GlobalLayout src_in_base_tile_;

// `src_tile_` is a basetile handled by a single warp.
GlobalLayout src_tile_;

using NonSwizzled =
tl::MatrixLayout<kSwizzledRows, kSwizzledCols, Shared::kRowStride, 1>;
using Swizzled = SwizzledLayout<NonSwizzled, B, M, S>;
using NonSwizzled = tl::RowMajor<BaseShape::kRows, BaseShape::kCols>;
using Swizzled = SwizzledLayout<NonSwizzled, 3, 3, 3>;

using SharedLayout =
std::conditional_t<Shared::kSwizzled, Swizzled, NonSwizzled>;
SharedLayout dst_tile_;
SharedLayout dst_in_base_tile_;

/// @brief Copies one access unit from `src` to `dst`.
/// @param src Source pointer, forwarded unchanged to `ld_global_st_shared`.
/// @param dst Destination pointer; converted to a shared-state-space address
///            before the transfer is issued.
DEVICE void copy(const DType* src, DType* dst) {
    // A single memory access moves 16 bytes.
    const uint32_t dst_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(dst));
    ld_global_st_shared<16>(dst_addr, src);
}

/**
 * @brief Returns the row index of the current thread within a warp.
 *
 * The lane id (`threadIdx.x % WARP_SIZE`) is divided by
 * `BaseShape::kColThreads` to obtain the row.
 */
DEVICE int lane_row_id() {
    // NOTE: When loading a RowMajor data tile, the threads in a warp are
    // interpreted as being arranged in a row-major fashion, so consecutive
    // lanes fill one row before moving on to the next.
    const int lane = threadIdx.x % WARP_SIZE;
    const int row = lane / BaseShape::kColThreads;
    return row;
}

/**
 * @brief Returns the column index of the current thread within a warp.
 *
 * The lane id (`threadIdx.x % WARP_SIZE`) is reduced modulo
 * `BaseShape::kColThreads` to obtain the column.
 */
DEVICE int lane_col_id() {
    // NOTE: When loading a RowMajor data tile, the threads in a warp are
    // interpreted as being arranged in a row-major fashion, so the column is
    // the remainder of the lane id over the threads-per-row count.
    const int lane = threadIdx.x % WARP_SIZE;
    const int col = lane % BaseShape::kColThreads;
    return col;
}
Expand Down Expand Up @@ -257,48 +233,40 @@ struct SharedToGlobalStorerImpl<Shared_, Global_, BaseShape, kRowExec_,
int row = lane_row_id();
int col = lane_col_id() * kNumPerAccess;

/// the pointer offset inside a warp tile.
int src_lane_offset = src_tile_(row, col);
int dst_lane_offset = dst_tile_(row, col);

int src_offset = 0, dst_offset = 0;
uint32_t src_ptr;
#pragma unroll
for (int i = 0; i < kRowExec; ++i) {
#pragma unroll
for (int j = 0; j < kColExec; ++j) {
int tile_i =
(i * BaseShape::kRows + row) / SwizzledBaseShape::kRows;
int tile_j =
(j * BaseShape::kCols + col) / SwizzledBaseShape::kCols;
int tile_row =
(i * BaseShape::kRows + row) % SwizzledBaseShape::kRows;
int tile_col =
(j * BaseShape::kCols + col) % SwizzledBaseShape::kCols;

int src_tile_offset = src_tile_(tile_row, tile_col);
int dst_lane_offset = dst_tile_(row, col);

src_offset = src_base_tiles_(tile_i, tile_j) + src_tile_offset;
src_offset = src_base_tiles_(i, j) + src_lane_offset;
dst_offset = dst_base_tiles_(i, j) + dst_lane_offset;

copy(src + src_offset, dst + dst_offset);
src_ptr = static_cast<uint32_t>(
__cvta_generic_to_shared(src + src_offset));

ld_shared_st_global<kAccessInBytes>(dst + dst_offset, src_ptr);
}
}
}

private:
using SwizzledBaseShape = traits::SwizzleBaseTileShape<DType>;
static constexpr int kSwizzledRows = SwizzledBaseShape::kRows;
static constexpr int kSwizzledCols = SwizzledBaseShape::kCols;
static constexpr int B = SwizzledBaseShape::B;
static constexpr int M = SwizzledBaseShape::M;
static constexpr int S = SwizzledBaseShape::S;

static constexpr int kSwizzledRowExec =
kRowExec / (kSwizzledRows / BaseShape::kRows);
static constexpr int kSwizzledColExec =
kColExec / (kSwizzledCols / BaseShape::kCols);

using SrcSwizzledLayout =
tl::MatrixLayout<kSwizzledRowExec, kSwizzledColExec,
kSwizzledRows * Shared::kRowStride, kSwizzledCols>;
SrcSwizzledLayout src_base_tiles_;
static constexpr int kNumPerAccess =
traits::AccessBase<DType>::kNumPerAccess;

static constexpr int kAccessInBytes =
traits::AccessBase<DType>::kAccessInBytes;

// a SharedTile is contiguously stored
using SrcBaseTilesLayout =
tl::MatrixLayout<kRowExec, kColExec,
BaseShape::kRows * Shared::kRowStride,
BaseShape::kNumel>;
SrcBaseTilesLayout src_base_tiles_;

using DstBaseTilesLayout =
tl::MatrixLayout<kRowExec, kColExec,
Expand All @@ -312,12 +280,10 @@ struct SharedToGlobalStorerImpl<Shared_, Global_, BaseShape, kRowExec_,
// consistent with those used in `SharedLayoutWrapper` within the
// register-to-shared storer.
static constexpr int kAccessInBits = 2 * int(sizeof(DType) * 8);
static constexpr int kNumPerAccess =
traits::AccessBase<DType>::kNumPerAccess;

using NonSwizzled =
tl::MatrixLayout<kSwizzledRows, kSwizzledCols, Shared::kRowStride, 1>;
using Swizzled = SwizzledLayout<NonSwizzled, B, M, S>;
using NonSwizzled = tl::RowMajor<BaseShape::kRows, BaseShape::kCols>;
using Swizzled = SwizzledLayout<NonSwizzled, 3, 3, 3>;

using SharedLayout =
std::conditional_t<Shared::kSwizzled, Swizzled, NonSwizzled>;
SharedLayout src_tile_;
Expand All @@ -335,11 +301,6 @@ struct SharedToGlobalStorerImpl<Shared_, Global_, BaseShape, kRowExec_,
/// @brief Returns the column index of the current thread within a warp,
///        computed as the lane id modulo `BaseShape::kColThreads`.
DEVICE int lane_col_id() {
    // Lane id of this thread inside its warp.
    const auto lane = threadIdx.x % WARP_SIZE;
    return lane % BaseShape::kColThreads;
}

/// @brief Copies one access unit from `src` to `dst`.
/// @param src Source pointer; converted to a shared-state-space address
///            before the transfer is issued.
/// @param dst Destination pointer, forwarded unchanged to
///            `ld_shared_st_global`.
DEVICE void copy(const DType* src, DType* dst) {
    // 16 is the access width handed to the load/store helper.
    const uint32_t src_addr =
        static_cast<uint32_t>(__cvta_generic_to_shared(src));
    ld_shared_st_global<16>(dst, src_addr);
}
};

template <typename Shared_, typename Global_, typename BaseShape_,
Expand Down Expand Up @@ -399,8 +360,9 @@ struct GlobalToSharedLoader {
using DType = Shared::DType;
using WarpLayout = WarpLayout_;

using BaseShape =
warp::WarpBaseTileShape<DType, typename Shared::Layout, Shared::kType>;
using WarpShape = TileShape<Shared::kRows / WarpLayout::kRows,
Shared::kCols / WarpLayout::kCols>;
using BaseShape = warp::WarpBaseTileShape<DType, WarpShape, Shared::kType>;

static_assert(Shared::kRows % BaseShape ::kRows == 0,
"Shared::kRows must be divisible by BaseShape::kRows.");
Expand Down Expand Up @@ -452,8 +414,9 @@ struct SharedToGlobalStorer {
using DType = Shared::DType;
using WarpLayout = WarpLayout_;

using BaseShape =
warp::WarpBaseTileShape<DType, typename Shared::Layout, Shared::kType>;
using WarpShape = TileShape<Shared::kRows / WarpLayout::kRows,
Shared::kCols / WarpLayout::kCols>;
using BaseShape = warp::WarpBaseTileShape<DType, WarpShape, Shared::kType>;

static_assert(Shared::kRows % BaseShape::kRows == 0,
"Shared::kRows must be divisible by BaseShape::kRows.");
Expand Down
Loading

0 comments on commit 4292eb4

Please sign in to comment.