Commit e14e255: merge in fix-graph-breaks
2 parents: 8dd4382 + 0168f9e

24 files changed: +459 -148 lines

.github/PULL_REQUEST_TEMPLATE.md

Lines changed: 10 additions & 0 deletions
@@ -39,6 +39,16 @@ FIX #xxxx (*link existing issues this PR will resolve*)
 <li>Please add documentation to <code>docs/source/</code> if the PR modifies the user-facing behaviors of vLLM. It helps vLLM users understand and utilize the new features or changes.</li>
 </ul>

+<h3>Adding or changing kernels</h3>
+<p>Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.</p>
+<ul>
+<li>Make sure custom ops are registered following PyTorch guidelines: <a href="https://pytorch.org/tutorials/advanced/cpp_custom_ops.html#cpp-custom-ops-tutorial">Custom C++ and CUDA Operators</a> and <a href="https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU">The Custom Operators Manual</a>.</li>
+<li>Custom operations that return <code>Tensors</code> require meta-functions. Meta-functions should be implemented and registered in Python so that dynamic dims can be handled automatically. See the documents above for a description of meta-functions.</li>
+<li>Use <a href="https://pytorch.org/docs/stable/library.html#torch.library.opcheck"><code>torch.library.opcheck()</code></a> to test the function registration and meta-function for any registered ops. See <code>tests/kernels</code> for examples.</li>
+<li>When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.</li>
+<li>If a new custom type is needed, see the following document: <a href="https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA">Custom Class Support in PT2</a>.</li>
+</ul>
+
 <h3>Notes for Large Changes</h3>
 <p>Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with <code>rfc-required</code> and might not go through the PR.</p>

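The new checklist items mention a schema, one or more backend implementations, and a meta-function per custom op. For orientation, here is a minimal, hedged C++ sketch of that registration pattern; the namespace demo_ops and the op scale are hypothetical, and vLLM's real bindings (e.g. csrc/cpu/torch_bindings.cpp below) differ in detail:

#include <torch/all.h>
#include <torch/library.h>

// Hypothetical op "demo_ops::scale"; a stand-in for a real kernel.
torch::Tensor scale_cpu(const torch::Tensor& x, double factor) {
  return x * factor;  // real computation
}

// Meta-function: same signature, but only shape/dtype propagation, no data.
torch::Tensor scale_meta(const torch::Tensor& x, double factor) {
  return torch::empty_like(x);
}

TORCH_LIBRARY(demo_ops, ops) {
  // The schema: argument types, mutability annotations, return type.
  ops.def("scale(Tensor x, float factor) -> Tensor");
}

TORCH_LIBRARY_IMPL(demo_ops, CPU, ops) {
  ops.impl("scale", &scale_cpu);  // backend implementation
}

TORCH_LIBRARY_IMPL(demo_ops, Meta, ops) {
  ops.impl("scale", &scale_meta);  // used by torch.compile / FakeTensor tracing
}

The template recommends implementing and registering meta-functions in Python via torch.library so dynamic dims are handled automatically; the C++ Meta-key route sketched above is the one this commit keeps for the marlin repack meta-functions further down.
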
csrc/cpu/torch_bindings.cpp

Lines changed: 4 additions & 4 deletions
@@ -27,8 +27,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // PagedAttention V2.
   ops.def(
       "paged_attention_v2("
-      " Tensor! out, Tensor exp_sums, Tensor max_logits,"
-      " Tensor tmp_out, Tensor query, Tensor key_cache,"
+      " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
+      " Tensor! tmp_out, Tensor query, Tensor key_cache,"
       " Tensor value_cache, int num_kv_heads, float scale,"
       " Tensor block_tables, Tensor seq_lens, int block_size,"
       " int max_seq_len, Tensor? alibi_slopes,"
@@ -95,8 +95,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {

   // Copy the cache blocks from src to dst.
   cache_ops.def(
-      "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
-      "block_mapping) -> ()");
+      "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
+      "Tensor block_mapping) -> ()");
   cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);

   // Reshape the key and value tensors and cache them.

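Two schema fixes here: paged_attention_v2 now marks exp_sums, max_logits, and tmp_out as mutated (Tensor!), and copy_blocks moves the write annotation onto the tensor-list arguments with explicit alias sets. These annotations tell the dispatcher and torch.compile which arguments a custom op writes into, so an inaccurate schema can lead to incorrect results under tracing. A standalone illustration of the two annotation forms, with hypothetical op names rather than vLLM code:

#include <torch/library.h>

// Hypothetical library "demo_anno"; only the schema strings are the point.
TORCH_LIBRARY(demo_anno, ops) {
  // "Tensor!" marks a single tensor argument the kernel writes into.
  ops.def("scale_out(Tensor! out, Tensor src, float factor) -> ()");
  // For tensor lists the write annotation carries an alias set, e.g. "(a!)";
  // a second mutated list gets its own set, "(b!)".
  ops.def("copy_into(Tensor(a!)[] key_dsts, Tensor(b!)[] value_dsts, Tensor mapping) -> ()");
}
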
csrc/custom_all_reduce.cu

Lines changed: 11 additions & 0 deletions
@@ -146,6 +146,17 @@ std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
   return {handles, std::move(offsets)};
 }

+std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta_meta(
+    fptr_t _fa) {
+  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
+  auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
+  auto options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
+  auto handles =
+      torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
+  return {handles, std::move(offsets)};
+}
+
 void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                             const std::vector<std::vector<int64_t>>& offsets) {
   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);

csrc/ops.h

Lines changed: 8 additions & 32 deletions
@@ -70,28 +70,15 @@ torch::Tensor aqlm_dequant(
     const torch::Tensor& codes, const torch::Tensor& codebooks,
     const std::vector<int64_t>& codebook_partition_sizes);

-torch::Tensor aqlm_dequant_meta(const torch::Tensor& codes,
-                                const torch::Tensor& codebooks,
-                                const torch::Tensor& codebook_partition_sizes);
-
 torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                        torch::Tensor _scaling_factors, torch::Tensor _zeros,
                        int64_t split_k_iters);

-torch::Tensor awq_gemm_meta(torch::Tensor _in_feats, torch::Tensor _kernel,
-                            torch::Tensor _scaling_factors,
-                            torch::Tensor _zeros, int64_t split_k_iters);
-
 torch::Tensor awq_dequantize(torch::Tensor _kernel,
                              torch::Tensor _scaling_factors,
                              torch::Tensor _zeros, int64_t split_k_iters,
                              int64_t thx, int64_t thy);

-torch::Tensor awq_dequantize_meta(torch::Tensor _kernel,
-                                  torch::Tensor _scaling_factors,
-                                  torch::Tensor _zeros, int64_t split_k_iters,
-                                  int64_t thx, int64_t thy);
-
 torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                           torch::Tensor& b_scales, torch::Tensor& workspace,
                           int64_t size_m, int64_t size_n, int64_t size_k);
@@ -123,11 +110,6 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                   int64_t size_m, int64_t size_n,
                                   int64_t size_k);

-torch::Tensor gptq_marlin_24_gemm_meta(
-    torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta,
-    torch::Tensor& b_scales, torch::Tensor& workspace, int64_t num_bits,
-    int64_t size_m, int64_t size_n, int64_t size_k);
-
 torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& b_scales, torch::Tensor& b_zeros,
                                torch::Tensor& g_idx, torch::Tensor& perm,
@@ -137,23 +119,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                bool is_k_full, bool has_zp,
                                bool use_fp32_reduce);

-torch::Tensor gptq_marlin_gemm_meta(
-    torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales,
-    torch::Tensor& b_zeros, torch::Tensor& g_idx, torch::Tensor& perm,
-    torch::Tensor& workspace, int64_t num_bits, int64_t size_m, int64_t size_n,
-    int64_t size_k, bool is_k_full, bool has_zp);
-
 torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
                                  int64_t num_bits);

 torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
-                                      torch::Tensor& perm, int64_t size_k,
-                                      int64_t size_n, int64_t num_bits);
+                                      torch::Tensor& perm, c10::SymInt size_k,
+                                      c10::SymInt size_n, int64_t num_bits);

 torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                 int64_t size_n, int64_t num_bits);

+torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
+                                     c10::SymInt size_k, c10::SymInt size_n,
+                                     int64_t num_bits);
+
 torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
                               int64_t n);
@@ -168,12 +148,6 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
                               int64_t size_k);

-torch::Tensor fp8_marlin_gemm_meta(torch::Tensor& a, torch::Tensor& b_q_weight,
-                                   torch::Tensor& b_scales,
-                                   torch::Tensor& workspace, int64_t num_bits,
-                                   int64_t size_m, int64_t size_n,
-                                   int64_t size_k);
-
 bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);

 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
@@ -281,6 +255,8 @@ void register_buffer(fptr_t _fa, torch::Tensor& t,
                      const std::vector<int64_t>& offsets);
 std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
     fptr_t _fa);
+std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta_meta(
+    fptr_t _fa);
 void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                             const std::vector<std::vector<int64_t>>& offsets);
 #endif

csrc/quantization/gptq_marlin/awq_marlin_repack.cu

Lines changed: 4 additions & 3 deletions
@@ -268,13 +268,14 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,

 #endif

-torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight, int64_t size_k,
-                                     int64_t size_n, int64_t num_bits) {
+torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
+                                     c10::SymInt size_k, c10::SymInt size_n,
+                                     int64_t num_bits) {
   int const pack_factor = 32 / num_bits;
   auto options = torch::TensorOptions()
                      .dtype(b_q_weight.dtype())
                      .device(b_q_weight.device());
-  return torch::empty(
+  return torch::empty_symint(
       {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
       options);
 }

csrc/quantization/gptq_marlin/gptq_marlin_repack.cu

Lines changed: 5 additions & 5 deletions
@@ -344,13 +344,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
 #endif

 torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
-                                      torch::Tensor& perm, int64_t size_k,
-                                      int64_t size_n, int64_t num_bits) {
+                                      torch::Tensor& perm, c10::SymInt size_k,
+                                      c10::SymInt size_n, int64_t num_bits) {
   int const pack_factor = 32 / num_bits;
   auto options = torch::TensorOptions()
                      .dtype(b_q_weight.dtype())
                      .device(b_q_weight.device());
-  return torch::empty({size_k / marlin::tile_size,
-                       size_n * marlin::tile_size / pack_factor},
-                      options);
+  return torch::empty_symint(
+      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
+      options);
 }

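Both repack meta-functions now take c10::SymInt sizes and allocate with torch::empty_symint instead of torch::empty. Under torch.compile with dynamic shapes, the sizes reaching a meta-function may be symbolic, so the output-shape arithmetic has to stay symbolic rather than forcing concrete integers. A self-contained sketch of the same pattern, with the hypothetical name demo_repack_meta and illustrative constants standing in for marlin::tile_size and the pack factor:

#include <torch/all.h>

// Hypothetical stand-in for the repack meta-functions above; "tile" and
// "pack" replace marlin::tile_size and pack_factor. The shape arithmetic on
// c10::SymInt stays symbolic, and torch::empty_symint accepts symbolic sizes
// where torch::empty would require concrete int64_t dimensions.
torch::Tensor demo_repack_meta(torch::Tensor& b_q_weight, c10::SymInt size_k,
                               c10::SymInt size_n, int64_t num_bits) {
  int64_t const tile = 16;             // illustrative tile size
  int64_t const pack = 32 / num_bits;  // values packed per 32-bit word
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  // SymInt / int64_t and SymInt * int64_t keep the result symbolic.
  return torch::empty_symint({size_k / tile, size_n * tile / pack}, options);
}

Whether this accounts for every graph break the branch name refers to is not visible from this excerpt, but the SymInt signatures are what let these meta-functions run under dynamic-shape compilation without specializing on concrete sizes.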