144 commits
69d46da
Add meta functions for ops to prevent graph breaks
bnellnm Jul 29, 2024
679470c
format
bnellnm Jul 29, 2024
08f969f
add torch.compile to loader + symint support for gptq_gemm_meta + twe…
bnellnm Jul 29, 2024
7f5946d
pull out punica support test, move torch.compile to runner to avoid w…
bnellnm Aug 1, 2024
2853e5d
tweaks
bnellnm Aug 1, 2024
1559470
change codebook_partition_sizes to List[int]
bnellnm Aug 2, 2024
929894d
use string schemas for all functions
bnellnm Aug 5, 2024
b707b57
back out lora test hacks
bnellnm Aug 5, 2024
43cbb23
cleanups
bnellnm Aug 5, 2024
7d9ab09
fix flash_attn
bnellnm Aug 5, 2024
33ab5fe
fix marlin schemas and meta funcs
bnellnm Aug 5, 2024
69bfa21
fix format
bnellnm Aug 6, 2024
729d99c
add some opcheck tests
bnellnm Aug 6, 2024
e59aa74
fix registrations for non-Tensor ops
bnellnm Aug 6, 2024
b7a851b
rebase + fix gguf registrations
bnellnm Aug 6, 2024
d57d913
update PR template with info on pytorch registration
bnellnm Aug 6, 2024
ea22ab5
try registering meta-function via python to handle symbolic shapes
bnellnm Aug 6, 2024
58fb6b6
format
bnellnm Aug 6, 2024
61ba2ad
conditionally register gptq_marlin_24_gemm_fake
bnellnm Aug 6, 2024
e31eedd
format stuff
bnellnm Aug 6, 2024
624732a
try python meta functions
bnellnm Aug 7, 2024
2469661
temporarily add opchecks to almost all custom ops
bnellnm Aug 7, 2024
6c1213b
comment out opchecks
bnellnm Aug 7, 2024
0c45db3
remove temporary opchecks in _custom_ops
bnellnm Aug 7, 2024
a32cff8
tweak copy_blocks schema
bnellnm Aug 8, 2024
715b731
remove most C++ meta functions
bnellnm Aug 8, 2024
a17e0ac
activation opcheck tests
bnellnm Aug 9, 2024
36091b4
add more opcheck tests
bnellnm Aug 9, 2024
b9e73e7
add more opcheck tests
bnellnm Aug 9, 2024
357a622
run opchecks on fewer combinations to reduce memory use
bnellnm Aug 9, 2024
e9970b5
use @youkaichao's flash_attn registration
bnellnm Aug 9, 2024
a130ca9
fix format
bnellnm Aug 9, 2024
a86d017
fix cutlass test
bnellnm Aug 9, 2024
c32f7d3
add custom op for tensor_model_parallel_all_reduce
SageMoore Aug 9, 2024
079fc84
format
SageMoore Aug 9, 2024
a6ca952
register lora triton ops to avoid dynamo problems
bnellnm Aug 9, 2024
d1b26f2
fix cpu support in tensor_model_parallel_all_reduce
SageMoore Aug 9, 2024
cbb2be9
format
SageMoore Aug 9, 2024
abe7865
cleanups
bnellnm Aug 5, 2024
88357b3
fix flash_attn signatures
bnellnm Aug 12, 2024
53f3147
rebase + cleanups
bnellnm Aug 12, 2024
908c254
tweaks + add gc.collect() to fix memory profiling errors when dynamo …
bnellnm Aug 13, 2024
d3eb42d
fix broken env var
bnellnm Aug 13, 2024
3dc4141
add clones to all_reduce
SageMoore Aug 14, 2024
6a4ad9c
fix format
bnellnm Aug 13, 2024
9cf3ac9
fix aqlm custom op type annotations
bnellnm Aug 14, 2024
68b3a74
fix gptq custom op registration
bnellnm Aug 14, 2024
3f4cce4
add dynamo support for ScalarType
bnellnm Aug 14, 2024
4015920
add some pointers to PT2 custom class docs
bnellnm Aug 14, 2024
87ccdac
tweaks
bnellnm Aug 16, 2024
38d1bda
fix merge
bnellnm Aug 16, 2024
ed4a565
fix cpu schemas
bnellnm Aug 16, 2024
af6302f
fix merge
bnellnm Aug 17, 2024
0168f9e
rebase + add meta functions for machete kernels
bnellnm Aug 20, 2024
d6243cd
Custom torch.compile backend prototype
bnellnm Apr 25, 2024
5d64ad1
integration wip
bnellnm May 17, 2024
a8a0103
wip
bnellnm May 21, 2024
4dc6caf
wip
bnellnm May 26, 2024
fcb5a03
wip
bnellnm May 27, 2024
9b248c4
wip
bnellnm May 31, 2024
721b38f
wip
bnellnm May 31, 2024
bdfeebc
merge
bnellnm Jun 3, 2024
2dd4d52
wip
bnellnm Jun 5, 2024
02c704c
progress
bnellnm Jun 6, 2024
584a0f0
optimize whole module
bnellnm Jun 6, 2024
8926cdd
remove partitioner
bnellnm Jun 7, 2024
ffd7c6f
wip
bnellnm Jun 9, 2024
73223d2
wip
bnellnm Jun 9, 2024
8b2749b
wip
bnellnm Jun 12, 2024
f54da8c
almost matching
bnellnm Jun 12, 2024
84fc16c
almost matching
bnellnm Jun 12, 2024
c48fb44
fix some stuff
bnellnm Jun 12, 2024
59e13db
flash attn support
bnellnm Jun 13, 2024
a66e257
return instead of throw
bnellnm Jun 13, 2024
fd13948
compile fixes
bnellnm Jun 13, 2024
17a5537
add support for multi output fused ops
bnellnm Jun 14, 2024
c042b2c
comment
bnellnm Jun 14, 2024
d59e905
wip inplace op fusion
bnellnm Jun 15, 2024
fff949c
fix node_users
bnellnm Jun 15, 2024
8f42e32
fix FlowGraph dependencies
bnellnm Jun 19, 2024
0e9b2e1
speed up SubGraph creation by reusing dependencies from FlowGraph
bnellnm Jun 19, 2024
60c14d2
fix input ordering issues for fused ops
bnellnm Jun 19, 2024
bca48ce
fix last_input method
bnellnm Jun 20, 2024
69e01e9
add torch.narrow to supported ops
bnellnm Jun 20, 2024
87bfb74
fp8 fixes
bnellnm Jun 20, 2024
353d7df
fix SubGraph toposort
bnellnm Jun 20, 2024
b270917
add constant arg values to mangled names
bnellnm Jun 20, 2024
81268c3
comments + cleanups
bnellnm Jun 21, 2024
032bcdf
rebase + fixes
bnellnm Jul 9, 2024
72021fb
fix
bnellnm Jul 9, 2024
fd59146
cleanup wip
bnellnm Jul 10, 2024
7120e21
refactor op generator
bnellnm Jul 10, 2024
ee25bf7
simplify generator
bnellnm Jul 11, 2024
e35b4af
kernels and kernel accessories
SageMoore Jul 11, 2024
a10520e
comments
SageMoore Jul 11, 2024
5bd856a
name mangling fixes
bnellnm Jul 11, 2024
d5e7601
add rms_norm support
bnellnm Jul 11, 2024
fbc6404
silu mul quant fixes
SageMoore Jul 12, 2024
d8ec04b
rename ex directory
bnellnm Jul 12, 2024
1aa7432
tweaks
bnellnm Jul 12, 2024
2b64edd
remove hard coded shape from silu mul quant
SageMoore Jul 12, 2024
554c5db
lint fixes
bnellnm Jul 12, 2024
3a03b96
more lint fixes
bnellnm Jul 12, 2024
62094ad
more lint fixes
bnellnm Jul 12, 2024
f905885
break up long lines and hopefully not break code
bnellnm Jul 12, 2024
b06c5e5
more lint fixes
bnellnm Jul 12, 2024
28094ec
don't link directly against vllm shared libs in naive codegen
bnellnm Jul 16, 2024
cd3f434
formatting
bnellnm Jul 16, 2024
c8da18e
deterministic subgraph topo sort
bnellnm Jul 16, 2024
aa12834
make sure things are only registered once
bnellnm Jul 17, 2024
1c1b824
fix bug in node_function_target, add empty __init__.py
bnellnm Jul 17, 2024
5148b44
enable optimizer for all models
bnellnm Jul 17, 2024
cee53ec
apply optimizer to all models + fix some bugs
bnellnm Jul 19, 2024
edb6107
removed const values from fused kernels
SageMoore Jul 22, 2024
c559fb0
cleanup
SageMoore Jul 22, 2024
42fb293
more cleanup
SageMoore Jul 23, 2024
9848f15
disabling const extraction stuff
SageMoore Jul 23, 2024
6a75a38
disabling const extraction stuff
SageMoore Jul 23, 2024
57ac451
wip valid subgraph
bnellnm Jul 24, 2024
e2b3182
fix valid subgraph
bnellnm Jul 24, 2024
8a14f60
temporarily disable optimizer so we can diagnose dynamo issues
bnellnm Jul 25, 2024
47771bb
fix flash_attn missing window_size
bnellnm Jul 26, 2024
feb1e84
fix advance schema
bnellnm Jul 26, 2024
8f42319
update gptq_marlin_gemm schema and meta fn
bnellnm Jul 26, 2024
45e6451
fix flash_attn registration
bnellnm Jul 26, 2024
bea04ba
re-enable silu_and_mul kernel
bnellnm Jul 26, 2024
c84c60c
a few tweaks
bnellnm Jul 26, 2024
325b861
more tweaks
bnellnm Jul 26, 2024
2a465e6
update torch to 2.4
SageMoore Jul 29, 2024
9931233
add some meta functions, revert some pt2.3 hacks
bnellnm Jul 29, 2024
10e73a2
comments
bnellnm Aug 12, 2024
8dd4382
move optimizer call
bnellnm Aug 20, 2024
e14e255
merge in fix-graph-breaks
bnellnm Aug 20, 2024
6793994
symint and other fixes
bnellnm Aug 23, 2024
355436f
add support for torch.Tensor.size
bnellnm Aug 23, 2024
5fa9530
rewrite gather function
bnellnm Aug 30, 2024
7de4709
revert some wip so fp8 model will run
bnellnm Sep 3, 2024
0b3f4bc
wip. TODO: dynamic_fp8 wrong answer
bnellnm Sep 4, 2024
d7fa26d
support floordiv
bnellnm Sep 9, 2024
578e8a1
update fusion passes for llama 3
SageMoore Sep 6, 2024
2e5d996
more fused operators
SageMoore Sep 6, 2024
6cbf5e4
prepare for commit
SageMoore Sep 10, 2024
cbb88ed
remove prints
SageMoore Sep 10, 2024
64a88bb
reorder args in silu_mul_quant
SageMoore Sep 10, 2024
10 changes: 10 additions & 0 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -39,6 +39,16 @@ FIX #xxxx (*link existing issues this PR will resolve*)
<li>Please add documentation to <code>docs/source/</code> if the PR modifies the user-facing behaviors of vLLM. It helps vLLM users understand and utilize the new features or changes.</li>
</ul>

<h3>Adding or changing kernels</h3>
<p>Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.</p>
<ul>
<li>Make sure custom ops are registered following PyTorch guidelines: <a href="https://pytorch.org/tutorials/advanced/cpp_custom_ops.html#cpp-custom-ops-tutorial">Custom C++ and CUDA Operators</a> and <a href="https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU">The Custom Operators Manual</a></li>
<li>Custom operations that return <code>Tensors</code> require meta-functions. Meta-functions should be implemented and registered in Python so that dynamic dims can be handled automatically. See the above documents for a description of meta-functions.</li>
<li>Use <a href="https://pytorch.org/docs/stable/library.html#torch.library.opcheck"><code>torch.library.opcheck()</code></a> to test the function registration and meta-function for any registered ops (a minimal sketch follows this diff). See <code>tests/kernels</code> for examples.</li>
<li>When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.</li>
<li>If a new custom type is needed, see the following document: <a href="https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA">Custom Class Support in PT2</a>.</li>
</ul>

<h3>Notes for Large Changes</h3>
<p>Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with <code>rfc-required</code> and might not go through the PR.</p>

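The checklist above is easier to follow with a concrete, simplified example. The sketch below is hypothetical -- the `mylib::row_sum` op and its schema are made up and stand in for a real C++-registered kernel -- but it shows the string schema, the Python meta ("fake") function that handles dynamic shapes, and the `torch.library.opcheck()` call that the bullet points describe.

```python
# A minimal sketch, not vLLM's actual registration code: define an op with a
# string schema, register a backend implementation, register a Python meta
# ("fake") function so dynamic dims work under torch.compile, and validate the
# registration with torch.library.opcheck(). The namespace "mylib" and the op
# "row_sum" are hypothetical; vLLM's real ops are defined in C++ via
# TORCH_LIBRARY under the _C extension.
import torch
from torch.library import Library, opcheck

my_lib = Library("mylib", "DEF")
my_lib.define("row_sum(Tensor x) -> Tensor")

def row_sum_impl(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a real CPU/CUDA kernel.
    return x.sum(dim=-1)

my_lib.impl("row_sum", row_sum_impl, "CPU")

@torch.library.register_fake("mylib::row_sum")
def row_sum_fake(x: torch.Tensor) -> torch.Tensor:
    # Shape-only computation; new_empty keeps dtype/device and handles symbolic
    # sizes, which is what prevents graph breaks for dynamic batch dimensions.
    return x.new_empty(x.shape[:-1])

# opcheck exercises the schema, the fake/meta function, and the registration.
opcheck(torch.ops.mylib.row_sum.default, (torch.randn(4, 8),))
```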
4 changes: 3 additions & 1 deletion CMakeLists.txt
@@ -213,7 +213,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/custom_all_reduce.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
"csrc/quantization/layernorm_kernels/rms_norm_quant.cu"
"csrc/quantization/layernorm_kernels/activation_kernels.cu")

#
# The CUTLASS kernels for Hopper require sm90a to be enabled.
8 changes: 4 additions & 4 deletions csrc/cpu/torch_bindings.cpp
@@ -27,8 +27,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// PagedAttention V2.
ops.def(
"paged_attention_v2("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
@@ -95,8 +95,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {

// Copy the cache blocks from src to dst.
cache_ops.def(
"copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
"block_mapping) -> ()");
"copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
"Tensor block_mapping) -> ()");
cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);

// Reshape the key and value tensors and cache them.
42 changes: 42 additions & 0 deletions csrc/ops.h
@@ -123,9 +123,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);

torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
torch::Tensor& perm, c10::SymInt size_k,
c10::SymInt size_n, int64_t num_bits);

torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits);

torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
c10::SymInt size_k, c10::SymInt size_n,
int64_t num_bits);

torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
int64_t n);

@@ -164,6 +172,35 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
int64_t size_n, int64_t size_k);
#endif

// These are kernels used by qqq
// torch::Tensor qqq_gemm(
// torch::Tensor& a,
// torch::Tensor& b_q_weight,
// torch::Tensor& s1,
// torch::Tensor& s2,
// torch::Tensor& s3,
// torch::Tensor& workspace,
// int64_t size_m,
// int64_t size_n,
// int64_t size_k);

void rms_norm_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& tmp, torch::Tensor const& weight,
torch::Tensor& scale, double const epsilon);

void add_residual_rms_norm_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& residual, torch::Tensor& tmp,
torch::Tensor const& weight,
torch::Tensor& scale, double const epsilon);

void silu_and_mul_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scale, torch::Tensor& tmp);

// void quant(
// torch::Tensor& out,
// torch::Tensor& input,
// torch::Tensor& scale);

void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale);

@@ -178,6 +215,11 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int64_t bit);

torch::Tensor gptq_gemm_meta(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int64_t bit);

void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);

void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
31 changes: 31 additions & 0 deletions csrc/quantization/aqlm/gemm_kernels.cu
@@ -595,3 +595,34 @@ torch::Tensor aqlm_dequant(
" entries is not currently supported.")
return {};
}

torch::Tensor aqlm_gemm_meta(const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const torch::Tensor& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias) {


auto out_features = codes.size(0) * codebooks.size(2);
auto flat_input = input.reshape({-1, input.size(-1)});
auto flat_output = torch::empty(
{flat_input.size(0), out_features},
torch::TensorOptions().dtype(input.dtype()).device(input.device()));

auto output_sizes = input.sizes().vec();
output_sizes.pop_back();
output_sizes.push_back(-1);
return flat_output.reshape(output_sizes);
}

torch::Tensor aqlm_dequant_meta(const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& codebook_partition_sizes) {
auto in_features = codes.size(1) * 8;
auto out_features = codes.size(0);
return torch::empty({out_features, in_features},
torch::TensorOptions()
.dtype(codebooks.dtype())
.device(codebooks.device()));
}
26 changes: 26 additions & 0 deletions csrc/quantization/awq/gemm_kernels.cu
@@ -524,3 +524,29 @@ torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
}
return _out_feats.sum(0);
}

torch::Tensor awq_gemm_meta(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros, int64_t split_k_iters) {
auto num_in_feats = _in_feats.size(0);
auto options = torch::TensorOptions()
.dtype(_in_feats.dtype())
.device(_in_feats.device());
return torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8},
options).sum(0);
}

torch::Tensor awq_dequantize_meta(torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros, int64_t split_k_iters,
int64_t thx, int64_t thy) {
auto in_c = _kernel.size(0);
auto qout_c = _kernel.size(1);
auto out_c = qout_c * 8;

auto options = torch::TensorOptions()
.dtype(_scaling_factors.dtype())
.device(_scaling_factors.device());

return torch::empty({in_c, out_c}, options);
}
9 changes: 9 additions & 0 deletions csrc/quantization/fp8/fp8_marlin.cu
@@ -1303,3 +1303,12 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}

#endif

torch::Tensor fp8_marlin_gemm_meta(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales,
torch::Tensor& workspace, int64_t num_bits,
int64_t size_m, int64_t size_n,
int64_t size_k) {
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
return torch::empty({size_m, size_n}, options);
}
17 changes: 17 additions & 0 deletions csrc/quantization/gptq/q_gemm.cu
@@ -1854,3 +1854,20 @@ void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) {
: (int*)q_perm.data_ptr(),
q_weight.size(0) * 32 / bit, q_weight.size(1), bit);
}

torch::Tensor gptq_gemm_meta(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int64_t bit) {
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
#if 0
// TODO: this might not be quite right, add check for symbolic dims and only
// use when needed?
auto const m = a.sym_size(0);
auto const n = b_q_weight.sym_size(1);
auto res = torch::empty_symint({m, n}, options);
#else
auto res = torch::empty({a.size(0), b_q_weight.size(1)}, options);
#endif
return res;
}
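The `#if 0` block above notes that symbolic sizes would need special handling in C++. Later commits in this PR ("try registering meta-function via python to handle symbolic shapes") move in the direction sketched below: an illustrative Python fake registration (only valid if no C++ meta is registered for the op), assuming `gptq_gemm` is bound under the `_C` extension namespace as in the CUDA bindings.

```python
# Illustrative only: a Python-side fake for gptq_gemm. With PyTorch 2.4,
# torch.library.register_fake lets SymInt sizes flow through without explicit
# sym_size()/empty_symint plumbing. Do not combine this with a C++ meta
# registration for the same op.
import torch
from torch.library import register_fake

@register_fake("_C::gptq_gemm")
def _gptq_gemm_fake(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
                    use_exllama, bit):
    # Same shape rule as the C++ meta above: [a.size(0), b_q_weight.size(1)],
    # in a's dtype and on a's device; the sizes may be symbolic under compile.
    return torch.empty((a.size(0), b_q_weight.size(1)),
                       dtype=a.dtype, device=a.device)
```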
12 changes: 12 additions & 0 deletions csrc/quantization/gptq_marlin/awq_marlin_repack.cu
@@ -267,3 +267,15 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
}

#endif

torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
c10::SymInt size_k, c10::SymInt size_n,
int64_t num_bits) {
int const pack_factor = 32 / num_bits;
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
return torch::empty_symint(
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
}
9 changes: 9 additions & 0 deletions csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -2297,3 +2297,12 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}

#endif

torch::Tensor gptq_marlin_gemm_meta(
torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales,
torch::Tensor& b_zeros, torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, bool has_zp) {
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
return torch::empty({size_m, size_n}, options);
}
12 changes: 12 additions & 0 deletions csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
@@ -342,3 +342,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
}

#endif

torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
torch::Tensor& perm, c10::SymInt size_k,
c10::SymInt size_n, int64_t num_bits) {
int const pack_factor = 32 / num_bits;
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
return torch::empty_symint(
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
}
114 changes: 114 additions & 0 deletions csrc/quantization/layernorm_kernels/activation_kernels.cu
@@ -0,0 +1,114 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include "../../cuda_compat.h"
#include "../../dispatch_utils.h"
#include "../../reduction_utils.cuh"
// #include "quant_utils.cuh"
#ifndef USE_ROCM
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
std::numeric_limits<FP8_TYPE>::max();
#else
#include "amd/hip_float8.h"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issues when running dynamic quantization. Here use 224.0f for ROCm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif
namespace vllm {

template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
float const scale) {
float x = 0.0f;
if constexpr (is_scale_inverted) {
x = val * scale;
} else {
x = val / scale;
}

float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
#ifndef USE_ROCM
return static_cast<c10::Float8_e4m3fn>(r);
#else
// Use hardware cvt instruction for fp8 on rocm
return c10::Float8_e4m3fnuz(hip_fp8(r).data,
c10::Float8_e4m3fnuz::from_bits());
#endif
}

static inline __device__ int8_t float_to_int8_rn(float x) {
uint32_t dst;
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
return reinterpret_cast<const int8_t&>(dst);
}

template <typename T>
__device__ __forceinline__ T silu(const T& x) {
// x * sigmoid(x)
return (T)(((float)x) / (1.0f + expf((float)-x)));
}

template <typename scalar_t>
__global__ void silu_and_mul_quant_kernel(
FP8_TYPE* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2 * d]
const int d,
float* __restrict__ scale, // [num_tokens]
float* __restrict__ tmp) {
const int64_t token_idx = blockIdx.x;
// float amax_val = 0.0f;

for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
// const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
// const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
// scalar_t t = silu(x) * y;
// input[token_idx * 2 * d + idx] = t;
// amax_val = fmaxf(amax_val, fabsf((float) t));
const float x = (float)VLLM_LDG(&input[token_idx * 2 * d + idx]);
const float y = (float)VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
float t = silu(x) * y;
tmp[token_idx * d + idx] = t;
// amax_val = fmaxf(amax_val, fabsf(t));
}

// __shared__ float s_amax;
// amax_val = blockReduceMax(amax_val);
// if (threadIdx.x == 0) {
// s_amax = amax_val;
// // scale[blockIdx.x] = amax_val / 127.0f;
// }
// __syncthreads();

// float tmp_scale = 127.0f / s_amax;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
// out[token_idx * d + idx] =
// float_to_int8_rn(tmp_scale * (float) input[token_idx * 2 * d + idx]);
// out[token_idx * d + idx] =
// float_to_int8_rn(tmp_scale * tmp[token_idx * d + idx]);
out[token_idx * d + idx] = scaled_fp8_conversion<false>(
tmp[token_idx * d + idx], *scale);
}
}
} // namespace vllm

void silu_and_mul_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., 2 * d]
torch::Tensor& scale, // [num_tokens]
torch::Tensor& tmp // [..., d]
) {
int d = input.size(-1) / 2;
int64_t num_tokens = input.numel() / input.size(-1);
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "silu_and_mul_quant_kernel", [&] {
vllm::silu_and_mul_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(), d,
scale.data_ptr<float>(), tmp.data_ptr<float>());
});
}
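For reference, a hypothetical Python-side call to this launcher, assuming it gets bound as `torch.ops._C.silu_and_mul_quant` (the CUDA torch_bindings change is not part of this excerpt). Shapes follow the signature comments: `input` is `[..., 2 * d]`, while `out` and `tmp` are `[..., d]`.

```python
# Hypothetical usage sketch; the binding name and dtypes are assumptions based
# on the launcher signature above (FP8_TYPE is c10::Float8_e4m3fn on CUDA).
import torch

num_tokens, d = 16, 4096
x = torch.randn(num_tokens, 2 * d, dtype=torch.float16, device="cuda")
out = torch.empty(num_tokens, d, dtype=torch.float8_e4m3fn, device="cuda")
tmp = torch.empty(num_tokens, d, dtype=torch.float32, device="cuda")
scale = torch.ones(num_tokens, dtype=torch.float32, device="cuda")
torch.ops._C.silu_and_mul_quant(out, x, scale, tmp)
```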