
[QST] Hopper mixed precision gemm always worse than FP8 #1549

Open
divchenko opened this issue May 24, 2024 · 8 comments

Comments

@divchenko

I'm doing an A (4-bit) x B (fp16) matmul with a large A and a small B. I expect it to beat an fp8 matmul, since it should be memory-bound.
In reality, it always seems to be worse.

Example:
Kernel code is here: https://gist.github.com/divchenko/9b02f40ae109e8dc8549afbde059d32e
It is called from Python:

import torch
import cuscratch

g = 64  # scale group size along k
m = 3584
n = 16
k = 8192

scale_k = (k + g - 1) // g  # number of scale groups per row of A

s = torch.ones((m, scale_k), dtype=torch.half, device="cuda")       # per-group fp16 scales for A
a = torch.ones((m, (k + 1) // 2), dtype=torch.int8, device="cuda")  # 4-bit A, two values packed per byte
b = torch.ones((n, k), dtype=torch.half, device="cuda")
d = torch.zeros((n, m), dtype=torch.half, device="cuda")

cuscratch.matmul_mixed(a, b.t(), d.t(), s, k, g)

The best perf I can get is with the stream-K scheduler (k is large indeed), but it is still very low on memory bandwidth (~20%).
The persistent tile scheduler is way worse for both the TMA and TMACooperative kernel schedules.
The fp8 implementation can reach ~60% of memory bandwidth and hence is faster, although it reads ~2x more bytes.
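
For context, a rough back-of-the-envelope traffic estimate with the sizes above (just a sketch: it assumes A, B, the scales, and D each cross HBM once, and that the fp8 variant stores A as one byte per element with the same fp16 B and D):

# Rough HBM traffic estimate for the problem above (not a measurement).
m, n, k, g = 3584, 16, 8192, 64

bytes_a_int4 = m * k // 2        # 4-bit A: two values per byte
bytes_scales = m * (k // g) * 2  # fp16 scales, one per group of g along k
bytes_b = n * k * 2              # fp16 B
bytes_d = m * n * 2              # fp16 D

mixed = bytes_a_int4 + bytes_scales + bytes_b + bytes_d
fp8 = m * k + bytes_b + bytes_d  # fp8 A: one byte per value

print(mixed / 2**20, fp8 / 2**20, fp8 / mixed)  # ~15.2 MiB vs ~28.4 MiB, ~1.9x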

Am I missing anything? Thank you!

@thakkarV
Collaborator

@jackkosaian @IonThruster

@IonThruster
Collaborator

Could you share more info on what exact C++ kernel is being picked in both cases? You may have to pick a custom tile size instead of what the builder provides by default. The default ones are more optimized for compute-bound cases.

@divchenko
Author

@IonThruster The full code is below. I've played with tile sizes; this is the best config I found.

#include <ATen/ATen.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/activation.h>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <pybind11/operators.h>
#include <torch/extension.h>

#include <cute/tensor.hpp>
#include <cutlass/epilogue/collective/collective_builder.hpp>
#include <cutlass/epilogue/collective/default_epilogue.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>

namespace cuscratch {

namespace {

#define CUTLASS_CHECK(status)                                      \
  {                                                                \
    cutlass::Status error = status;                                \
    if (error != cutlass::Status::kSuccess) {                      \
      std::stringstream ss;                                        \
      ss << "Got cutlass error: " << cutlassGetStatusString(error) \
         << " at: " << __LINE__ << std::endl;                      \
      throw std::runtime_error(ss.str());                          \
    }                                                              \
  }

void matmul_mixed(torch::Tensor tensor_a, torch::Tensor tensor_b,
                torch::Tensor tensor_d, torch::Tensor tensor_scale, int64_t k,
                int64_t group_size) {
  using MmaType = cutlass::half_t;
  using QuantType = cutlass::int4b_t;
  constexpr int TileShapeK = (128 * 8) / cutlass::sizeof_bits<MmaType>::value;

  // A matrix configuration
  using ElementA = QuantType;
  using LayoutA = cutlass::layout::RowMajor;
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  // B matrix configuration
  using ElementB = MmaType;
  using LayoutB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  using ElementZero = cutlass::half_t;
  using ElementScale = cutlass::half_t;
  using LayoutScale = cutlass::layout::RowMajor;

  // C/D matrix configuration
  using ElementD = cutlass::half_t;
  using LayoutD = cutlass::layout::ColumnMajor;
  constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

  using ElementC = void;
  using LayoutC = LayoutD;  // Layout type for C and D matrix operands
  constexpr int AlignmentC = AlignmentD;

  // Core kernel configurations
  using ElementAccumulator = float;  // Element type for internal accumulation
  using ElementCompute = float;      // Element type for epilogue computation
  using ArchTag = cutlass::arch::Sm90;

  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using TileShape = cute::Shape<cute::_128, cute::_16, cute::Int<TileShapeK>>;
  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;

  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, EpilogueTileType, ElementAccumulator,
          ElementAccumulator, ElementC, LayoutC, AlignmentC, ElementD, LayoutD,
          AlignmentD, EpilogueSchedule>::CollectiveOp;

  using StageCount = cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>;

  // The Scale information must get paired with the operand A that will be
  // scaled.
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass, cute::tuple<ElementA, ElementScale>, LayoutA,
          AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator,
          TileShape, ClusterShape, StageCount, KernelSchedule>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::StreamKScheduler>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  /// Initialization
  typename GemmKernel::StrideA stride_a{};
  cute::get<0>(stride_a) = static_cast<int>(k);
  typename GemmKernel::StrideB stride_b{};
  cute::get<0>(stride_b) = static_cast<int>(tensor_b.stride(1));
  typename Gemm::GemmKernel::StrideC stride_c{};
  typename Gemm::GemmKernel::StrideD stride_d{};
  cute::get<1>(stride_d) = static_cast<int>(tensor_d.stride(1));

  // Scale and Zero share a stride since the layout and shapes must be the same.
  using StrideS = typename CollectiveMainloop::StrideScale;
  StrideS stride_s;

  // Data
  auto data_a = reinterpret_cast<const cutlass::int4b_t *>(tensor_a.data_ptr());
  auto data_b = reinterpret_cast<const cutlass::half_t *>(tensor_b.data_ptr());
  auto data_c = nullptr;  // no C operand: ElementC is void
  auto data_d = reinterpret_cast<cutlass::half_t *>(tensor_d.data_ptr());  // D is written, so non-const
  auto data_scale =
      reinterpret_cast<const cutlass::half_t *>(tensor_scale.data_ptr());


  typename Gemm::Arguments args;
  args.mode = cutlass::gemm::GemmUniversalMode::kGemm;
  args.problem_shape = {static_cast<int>(tensor_a.size(0)),
                        static_cast<int>(tensor_b.size(1)), static_cast<int>(k),
                        1};
  args.mainloop = {data_a,
                   stride_a,
                   data_b,
                   stride_b,
                   data_scale,
                   stride_s,
                   static_cast<int>(group_size),
                   nullptr,
                   4 /*mma_promotion_interval*/};
  args.epilogue = {{1, 0} /*alpha, beta*/, data_c, stride_c, data_d, stride_d};

  Gemm gemm;

  auto ws_size = static_cast<int64_t>(Gemm::get_workspace_size(args));
  auto ws_tensor = at::empty({ws_size}, at::TensorOptions()
                                            .dtype(at::ScalarType::Byte)
                                            .device(tensor_d.device())
                                            .requires_grad(false));

  CUTLASS_CHECK(gemm.can_implement(args));
  CUTLASS_CHECK(gemm.initialize(args, ws_tensor.data_ptr()));
  CUTLASS_CHECK(gemm.run());
}

}  // namespace

}  // namespace cuscratch

PYBIND11_MODULE(cuscratch, m) { m.def("matmul_mixed", cuscratch::matmul_mixed); }

@divchenko
Author

@IonThruster For the fp8 version, you can look at my old post #1139

@rawnhenry

rawnhenry commented May 25, 2024

@divchenko This behavior is expected with the current implementation. I have not done a deep dive into the performance, but I have a theory that may explain the behavior you observe.

If we take a compute-bound case, we typically have an MMA tile of MxNxK = 64x256x32, which means A's tile is 64x32 and B's tile is 256x32. We must convert 64x32 elements of A from INT4 to the MMA type (FP16 here), but in the compute-bound case we can hide that latency behind loading the large B matrix from smem and behind the big tensor core instruction. This is because the Hopper TCs are asynchronous, so while we do the MMA for stage k, we can be converting the data for stage k+1.

In the memory-bound case, we have much smaller tiles. In your example, it is 64x16x32, meaning A's tile size is still 64x32 but B's is way smaller at 16x32. The amount of A data we must convert is exactly the same as before, but we can no longer hide this latency behind a big tensor core instruction. I think this extra exposed latency is causing the slowdown in the memory-bound case.

My theory is that the conversion cost is exposed in the memory bound case.
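
To make that concrete, a rough ratio of A-conversion work to MMA work per mainloop stage for the two tile shapes above (just arithmetic, not a measurement):

# Ratio of A elements converted per stage to the MACs available to hide them.
def conversion_to_mma_ratio(tile_m, tile_n, tile_k):
    a_elems = tile_m * tile_k            # A elements to convert per stage
    mma_macs = tile_m * tile_n * tile_k  # multiply-accumulates per stage
    return a_elems / mma_macs            # simplifies to 1 / tile_n

print(conversion_to_mma_ratio(64, 256, 32))  # compute-bound tile: ~0.0039
print(conversion_to_mma_ratio(64, 16, 32))   # memory-bound tile:  0.0625 (16x more exposed)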

DISCLAIMER: I don't have data supporting what I've said above. It could be completely wrong; it's just a hunch :)

@divchenko
Author

Thanks @rawnhenry. The memory-bound case for fp8 (where I have 64x16x256 tiles) actually works quite well, reaching close to 60% of memory bandwidth. It's the mixed-precision case with a 128x16x64 tile (the k tile is restricted to be at most 64 == the scaling group size) that doesn't work well.
Two options I see (rough iteration counts sketched below):

  1. If I don't use the stream-K tile scheduler, occupancy is the same as in the fp8 case (~50% of the grid, but that seems enough to saturate HBM); it looks like, because of the small-ish k tile (64 compared to 256), the latency is not hidden well, as you mentioned.
  2. If I use the stream-K tile scheduler, occupancy is full, but, again, the tiles are quite small and the conversions and k-tile streaming are likely not hidden well.
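
For a sense of scale, the mainloop iteration counts implied by the two k tiles (simple arithmetic, assuming one mainloop iteration per k tile):

k = 8192
iters_fp8 = k // 256   # fp8 kernel, k tile 256 -> 32 mainloop iterations
iters_mixed = k // 64  # mixed kernel, k tile capped at group size 64 -> 128 iterations
print(iters_fp8, iters_mixed)
# 4x as many, shorter iterations: less work per stage to hide the int4 -> fp16
# conversion and the scale loads behind.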

