diff --git a/include/cutlass/detail/layout.hpp b/include/cutlass/detail/layout.hpp index 49bad92dfa..1616544291 100644 --- a/include/cutlass/detail/layout.hpp +++ b/include/cutlass/detail/layout.hpp @@ -87,7 +87,7 @@ struct TagToStrideB { // Maps to modes [M, K, L] template <> struct TagToStrideA { - using UnderlyingType = cute::Stride, int64_t>; + using UnderlyingType = cute::Stride, cute::Int<0>>; using type = UnderlyingType*; using tag = layout::RowMajor; }; @@ -95,7 +95,7 @@ struct TagToStrideA { // Maps to modes [M, K, L] template <> struct TagToStrideA { - using UnderlyingType = cute::Stride, int64_t, int64_t>; + using UnderlyingType = cute::Stride, int64_t, cute::Int<0>>; using type = UnderlyingType*; using tag = layout::ColumnMajor; }; @@ -103,7 +103,7 @@ struct TagToStrideA { // Maps to modes [N, K, L] template <> struct TagToStrideB { - using UnderlyingType = cute::Stride, int64_t, int64_t>; + using UnderlyingType = cute::Stride, int64_t, cute::Int<0>>; using type = UnderlyingType*; using tag = layout::RowMajor; }; @@ -111,7 +111,7 @@ struct TagToStrideB { // Maps to modes [N, K, L] template <> struct TagToStrideB { - using UnderlyingType = cute::Stride, int64_t>; + using UnderlyingType = cute::Stride, cute::Int<0>>; using type = UnderlyingType*; using tag = layout::ColumnMajor; }; diff --git a/tools/util/include/cutlass/util/packed_stride.hpp b/tools/util/include/cutlass/util/packed_stride.hpp index 8973513f4e..a3ed56a703 100644 --- a/tools/util/include/cutlass/util/packed_stride.hpp +++ b/tools/util/include/cutlass/util/packed_stride.hpp @@ -108,6 +108,30 @@ make_cute_packed_stride(cute::Stride, IntT, int64_t> s, cute::Shape ///////////////////////////////////////////////////////////////////////////////////////////////// +// Strides with group mode + +template +cute::Stride, cute::Int<0>> +make_cute_packed_stride(cute::Stride, cute::Int<0>> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + return s_copy; +} + +template +cute::Stride, StrideIntT, cute::Int<0>> +make_cute_packed_stride(cute::Stride, StrideIntT, cute::Int<0>> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + // Strides for convolutions // Output cutlass::layout::TensorNDHWC -> rank-3 stride (InT,_1,_0)