Skip to content

Commit

Permalink
perf: Optimize tensor conversions in C++ code to avoid unnecessary co…
Browse files Browse the repository at this point in the history
…pies (#366)

Small tweak to avoid unnecessary copying by combining `to` calls.
Discovered during profiling.
  • Loading branch information
Yard1 authored Jul 10, 2024
1 parent 264082e commit 1116237
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion python/csrc/batch_decode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
torch::Tensor o = torch::empty_like(q);
torch::Tensor lse;
if (return_lse) {
lse = torch::empty({batch_size, num_qo_heads}, q.options()).to(torch::kFloat32);
lse = torch::empty({batch_size, num_qo_heads}, q.options().dtype((torch::kFloat32)));
}

TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");
Expand Down
16 changes: 8 additions & 8 deletions python/csrc/batch_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
CHECK_DIM(1, qo_indptr);
CHECK_DIM(1, workspace_buffer);
qo_indptr = qo_indptr.to(torch::kCPU).to(torch::kInt32);
paged_kv_indptr = paged_kv_indptr.to(torch::kCPU).to(torch::kInt32);
qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
paged_kv_indptr = paged_kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
auto device = workspace_buffer.device();
size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
Expand Down Expand Up @@ -111,7 +111,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
torch::Tensor o = torch::empty_like(q, q.options());
torch::Tensor lse = torch::empty({0});
if (return_lse) {
lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32);
lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32));
}
MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone;
TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");
Expand Down Expand Up @@ -226,7 +226,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
torch::Tensor o = torch::empty_like(q, q.options());
torch::Tensor lse = torch::empty({0});
if (return_lse) {
lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32);
lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32));
}
constexpr MaskMode MASK_MODE = MaskMode::kCustom;
TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");
Expand Down Expand Up @@ -288,8 +288,8 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
CHECK_DIM(1, qo_indptr);
CHECK_DIM(1, workspace_buffer);
qo_indptr = qo_indptr.to(torch::kCPU).to(torch::kInt32);
kv_indptr = kv_indptr.to(torch::kCPU).to(torch::kInt32);
qo_indptr = qo_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
kv_indptr = kv_indptr.to(torch::dtype(torch::kInt32).device(torch::kCPU));
size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
auto device = workspace_buffer.device();
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
Expand Down Expand Up @@ -354,7 +354,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
torch::Tensor o = torch::empty_like(q, q.options());
torch::Tensor lse = torch::empty({0});
if (return_lse) {
lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32);
lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype(torch::kFloat32));
}

MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone;
Expand Down Expand Up @@ -452,7 +452,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
torch::Tensor o = torch::empty_like(q, q.options());
torch::Tensor lse = torch::empty({0});
if (return_lse) {
lse = torch::empty({nnz_qo, num_qo_heads}, q.options()).to(torch::kFloat32);
lse = torch::empty({nnz_qo, num_qo_heads}, q.options().dtype((torch::kFloat32)));
}

constexpr MaskMode MASK_MODE = MaskMode::kCustom;
Expand Down

0 comments on commit 1116237

Please sign in to comment.