Skip to content

Commit

Permalink
group gemm set stride L = cute::Int<0> (#1416)
Browse files Browse the repository at this point in the history
  • Loading branch information
Xseventh authored Mar 20, 2024
1 parent 629f465 commit c4e3e12
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
8 changes: 4 additions & 4 deletions include/cutlass/detail/layout.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,31 +87,31 @@ struct TagToStrideB<layout::ColumnMajor> {
// Maps to modes [M, K, L]
// Stride mapping for operand A when the layout tag is a pointer to layout::RowMajor.
// NOTE(review): per the commit title this pointer-tag form appears to be the grouped-GEMM
// variant -- confirm against the callers.
template <>
struct TagToStrideA<layout::RowMajor *> {
// NOTE(review): this listing is a rendered diff, so UnderlyingType appears twice. The line
// below looks like the old definition (mode L stride is a runtime int64_t).
using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, int64_t>;
// Replacement definition: mode M stride is a runtime int64_t, mode K stride is the static
// constant 1, and mode L stride is the static constant 0.
using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>;
// The exposed stride type is a pointer to the underlying stride type.
using type = UnderlyingType*;
// The layout tag with the pointer removed.
using tag = layout::RowMajor;
};

// Maps to modes [M, K, L]
// Stride mapping for operand A when the layout tag is a pointer to layout::ColumnMajor.
// NOTE(review): per the commit title this pointer-tag form appears to be the grouped-GEMM
// variant -- confirm against the callers.
template <>
struct TagToStrideA<layout::ColumnMajor *> {
// NOTE(review): this listing is a rendered diff, so UnderlyingType appears twice. The line
// below looks like the old definition (mode L stride is a runtime int64_t).
using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, int64_t>;
// Replacement definition: mode M stride is the static constant 1, mode K stride is a
// runtime int64_t, and mode L stride is the static constant 0.
using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, cute::Int<0>>;
// The exposed stride type is a pointer to the underlying stride type.
using type = UnderlyingType*;
// The layout tag with the pointer removed.
using tag = layout::ColumnMajor;
};

// Maps to modes [N, K, L]
// Stride mapping for operand B when the layout tag is a pointer to layout::RowMajor.
// NOTE(review): per the commit title this pointer-tag form appears to be the grouped-GEMM
// variant -- confirm against the callers.
template <>
struct TagToStrideB<layout::RowMajor *> {
// NOTE(review): this listing is a rendered diff, so UnderlyingType appears twice. The line
// below looks like the old definition (mode L stride is a runtime int64_t).
using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, int64_t>;
// Replacement definition: mode N stride is the static constant 1, mode K stride is a
// runtime int64_t, and mode L stride is the static constant 0.
using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, cute::Int<0>>;
// The exposed stride type is a pointer to the underlying stride type.
using type = UnderlyingType*;
// The layout tag with the pointer removed.
using tag = layout::RowMajor;
};

// Maps to modes [N, K, L]
// Stride mapping for operand B when the layout tag is a pointer to layout::ColumnMajor.
// NOTE(review): per the commit title this pointer-tag form appears to be the grouped-GEMM
// variant -- confirm against the callers.
template <>
struct TagToStrideB<layout::ColumnMajor *> {
// NOTE(review): this listing is a rendered diff, so UnderlyingType appears twice. The line
// below looks like the old definition (mode L stride is a runtime int64_t).
using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, int64_t>;
// Replacement definition: mode N stride is a runtime int64_t, mode K stride is the static
// constant 1, and mode L stride is the static constant 0.
using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>;
// The exposed stride type is a pointer to the underlying stride type.
using type = UnderlyingType*;
// The layout tag with the pointer removed.
using tag = layout::ColumnMajor;
};
Expand Down
24 changes: 24 additions & 0 deletions tools/util/include/cutlass/util/packed_stride.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,30 @@ make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT, int64_t> s, cute::Shape

/////////////////////////////////////////////////////////////////////////////////////////////////

// Strides with group mode

/// Fills in the single runtime stride of a rank-3 stride whose mode 1 is the static
/// constant 1 and whose mode 2 is the static constant 0.
///
/// @tparam StrideIntT  Integral type of the runtime (mode 0) stride.
/// @param  s           Stride to start from; only its mode 0 entry is overwritten.
/// @param  shape_MKL   Rank-3 shape; only its mode 1 extent is read.
/// @return A copy of `s` whose mode 0 stride equals the mode 1 extent of `shape_MKL`.
template <class StrideIntT>
cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>>
make_cute_packed_stride(cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
  static_assert(std::is_integral_v<StrideIntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");

  // `s` is received by value, so it can be updated in place and returned directly.
  const auto contiguous_extent = cute::get<1>(shape_MKL);
  cute::get<0>(s) = static_cast<StrideIntT>(contiguous_extent);
  return s;
}

/// Fills in the single runtime stride of a rank-3 stride whose mode 0 is the static
/// constant 1 and whose mode 2 is the static constant 0.
///
/// @tparam StrideIntT  Integral type of the runtime (mode 1) stride.
/// @param  s           Stride to start from; only its mode 1 entry is overwritten.
/// @param  shape_MKL   Rank-3 shape; only its mode 0 extent is read.
/// @return A copy of `s` whose mode 1 stride equals the mode 0 extent of `shape_MKL`.
template <class StrideIntT>
cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>>
make_cute_packed_stride(cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
  static_assert(std::is_integral_v<StrideIntT>,
    "Stride must have an integral type so it can be set dynamically. Static strides not supported.");

  // `s` is received by value, so it can be updated in place and returned directly.
  const auto contiguous_extent = cute::get<0>(shape_MKL);
  cute::get<1>(s) = static_cast<StrideIntT>(contiguous_extent);
  return s;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

// Strides for convolutions

// Output cutlass::layout::TensorNDHWC -> rank-3 stride (InT,_1,_0)
Expand Down

0 comments on commit c4e3e12

Please sign in to comment.