16 changes: 15 additions & 1 deletion CMakeLists.txt
@@ -377,8 +377,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()

# FP4 Archs and flags
cuda_archs_loose_intersection(FP4_ARCHS "10.0;10.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND FP4_ARCHS)
set(SRCS
"csrc/quantization/fp4/fp4_gemm_template.cu"
"csrc/quantization/fp4/quantization.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${FP4_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
message(STATUS "Building FP4 for archs: ${FP4_ARCHS}")
else()
message(STATUS "Not building FP4 as no compatible archs were found.")
endif()

#
# Machete kernels

# The machete kernels only work on hopper and require CUDA 12.0 or later.
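Note that the FP4_ARCHS gate above only selects which SMs the new sources are compiled for. A source-level guard of the following shape is the usual companion so the translation units still compile for other targets; this is a hypothetical sketch, not the PR's actual guard in fp4_gemm_template.cu:

// Hypothetical arch guard; __CUDA_ARCH__ is 1000 on SM 10.0 (Blackwell).
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 1000
  // FP4 kernel bodies would be compiled out below SM 10.0 here.
#endif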
72 changes: 72 additions & 0 deletions csrc/quantization/fp4/cudaUtils.h
@@ -0,0 +1,72 @@
#pragma once

#include <torch/all.h>

#include "cutlass/cutlass.h"

#include <cuda_runtime.h>

#include <algorithm>
#include <climits>
#include <cstdint>
#include <initializer_list>
#include <stdexcept>
#include <string>

/**
 * Helper macro for checking CUTLASS status codes
 */
#define CUTLASS_CHECK(status)                         \
  {                                                   \
    TORCH_CHECK(status == cutlass::Status::kSuccess,  \
                cutlassGetStatusString(status));      \
  }

inline uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

class CudaException : public std::runtime_error {
public:
CudaException(const std::string& file, int line, const std::string& message)
: std::runtime_error("CUDA Error at " + file + ":" +
std::to_string(line) + " - " + message) {}
};

template <typename T>
void check(T result, const char* func, const char* file, int line) {
if (result != cudaSuccess) {
throw CudaException(
file, line,
std::string("[TensorRT-LLM][ERROR] CUDA runtime error in ") + func +
": " + cudaGetErrorString(static_cast<cudaError_t>(result)));
}
}

template <typename T>
void checkEx(T result, std::initializer_list<T> const& validReturns,
             char const* const func, char const* const file, int const line) {
  // Throw only if the result matches none of the accepted return codes.
  if (std::all_of(std::begin(validReturns), std::end(validReturns),
                  [&result](T const& t) { return t != result; })) {
    throw CudaException(
        file, line,
        std::string("[TensorRT-LLM][ERROR] CUDA runtime error in ") + func +
            ": " + cudaGetErrorString(static_cast<cudaError_t>(result)));
  }
}

#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
#define check_cuda_error_2(val, file, line) check((val), #val, file, line)
#define sync_check_cuda_error() \
  check_cuda_error(cudaDeviceSynchronize())

inline int getMaxSharedMemoryPerBlockOptin() {
int device_id;
int max_shared_memory_per_block;
check_cuda_error(cudaGetDevice(&device_id));
check_cuda_error(cudaDeviceGetAttribute(
&max_shared_memory_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin,
device_id));
return max_shared_memory_per_block;
}

inline int getSMVersion() {
int device{-1};
check_cuda_error(cudaGetDevice(&device));
int sm_major = 0;
int sm_minor = 0;
check_cuda_error(cudaDeviceGetAttribute(
&sm_major, cudaDevAttrComputeCapabilityMajor, device));
check_cuda_error(cudaDeviceGetAttribute(
&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
return sm_major * 10 + sm_minor;
}
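A minimal usage sketch of the helpers above (not part of the diff); the SM threshold of 100 is an assumption that mirrors the 10.0 FP4 arch gate in the CMakeLists hunk:

inline void example_fp4_guard() {
  // Gate FP4 code paths on compute capability 10.0+ (getSMVersion() == 100).
  if (getSMVersion() < 100) {
    throw std::runtime_error("FP4 kernels require SM 10.0 or newer.");
  }
  // check_cuda_error throws CudaException with file/line context on failure.
  void* buf = nullptr;
  check_cuda_error(cudaMalloc(&buf, 1 << 20));
  check_cuda_error(cudaFree(buf));
}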
68 changes: 68 additions & 0 deletions csrc/quantization/fp4/fp4_gemm.h
@@ -0,0 +1,68 @@
#pragma once

#include <cuda_runtime.h>
#include <torch/all.h>

#include <cstddef>
#include <vector>

// NOTE: CutlassGemmConfig is expected to be provided by the cutlass
// extensions headers that accompany this PR.
/*
  This runner supports:
    FP4 inputs (A and B)
    float blockwise scaling factors
    float alpha scalings
    T output (D) where T = {float, half, __nv_bfloat16}

  Activations, biases, and outputs are all assumed to be row-major.
  Weights are assumed to be column-major.
  Block scaling factors are interleaved.
*/

class CutlassFp4GemmRunnerInterface {
public:
CutlassFp4GemmRunnerInterface() {}

virtual ~CutlassFp4GemmRunnerInterface() {}

virtual void cutlass_scaled_fp4_mm(
    void* D, void const* A, void const* B, void const* input_sf,
    void const* weight_sf, float const* global_sf, int m, int n, int k,
    CutlassGemmConfig gemmConfig, char* workspace,
    const size_t workspaceBytes, cudaStream_t stream) = 0;

// Returns desired workspace size in bytes.
virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0;

virtual std::vector<CutlassGemmConfig> getConfigs() const = 0;
};

template <typename T>
class CutlassFp4GemmRunner : public virtual CutlassFp4GemmRunnerInterface {
public:
CutlassFp4GemmRunner();
~CutlassFp4GemmRunner();

void cutlass_scaled_fp4_mm(
    torch::Tensor& D, torch::Tensor& A, torch::Tensor& B,
    torch::Tensor& input_sf, torch::Tensor& weight_sf,
    torch::Tensor& global_sf, CutlassGemmConfig gemmConfig,
    torch::Tensor& workspace, const size_t workspaceBytes);

void cutlass_scaled_fp4_mm(void* D, void const* A, void const* B,
void const* input_sf, void const* weight_sf,
float const* global_sf, int m, int n, int k,
CutlassGemmConfig gemmConfig, char* workspace,
const size_t workspaceBytes,
cudaStream_t stream) override;

// Returns desired workspace size in bytes.
size_t getWorkspaceSize(int const m, int const n, int const k) override;

std::vector<CutlassGemmConfig> getConfigs() const override;

private:
size_t dispatchToArch(T* D, void const* A, void const* B,
void const* input_sf, void const* weight_sf,
float const* global_sf, int m, int n, int k,
CutlassGemmConfig gemmConfig, char* workspace,
const size_t workspaceBytes, cudaStream_t stream,
int* occupancy = nullptr);

size_t getWorkspaceSizeImpl(int const m, int const n, int const k);

int mSm;
int mMultiProcessorCount;
};
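To make the intended call pattern concrete, a hypothetical call-site sketch (not the PR's actual dispatch code): it assumes cudaUtils.h is included for check_cuda_error, <cuda_fp16.h> for half, and that the CutlassGemmConfig value comes from getConfigs() in practice:

inline void example_fp4_mm(void* D, void const* A, void const* B,
                           void const* input_sf, void const* weight_sf,
                           float const* global_sf, int m, int n, int k,
                           CutlassGemmConfig cfg, cudaStream_t stream) {
  CutlassFp4GemmRunner<half> runner;  // half-precision output type T
  size_t const ws_bytes = runner.getWorkspaceSize(m, n, k);
  char* workspace = nullptr;
  check_cuda_error(cudaMalloc(reinterpret_cast<void**>(&workspace), ws_bytes));
  runner.cutlass_scaled_fp4_mm(D, A, B, input_sf, weight_sf, global_sf, m, n,
                               k, cfg, workspace, ws_bytes, stream);
  check_cuda_error(cudaFree(workspace));
}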