torch custom_op fix for rope #569

Merged 1 commit on Oct 30, 2024
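This PR changes the RoPE entry points from returning `std::vector<torch::Tensor>` to returning `void`; the rotated `q_rope`/`k_rope` tensors are now caller-provided buffers that the kernels write into. That shape matches the contract PyTorch enforces when a function is registered as a torch custom op with mutated arguments: an op may not return (or alias) a tensor it mutates. Below is a minimal sketch of that contract, assuming a recent PyTorch (2.4+) with `torch.library.custom_op`; the `demo::` op name and the function body are hypothetical stand-ins, not flashinfer's actual registration code:

```python
# Minimal sketch of the torch.library contract this PR targets. The
# "demo::" namespace and the placeholder body are hypothetical, not
# flashinfer's real kernels.
import torch

@torch.library.custom_op("demo::apply_rope_inplace", mutates_args=("q_rope", "k_rope"))
def apply_rope_inplace(q: torch.Tensor, k: torch.Tensor,
                       q_rope: torch.Tensor, k_rope: torch.Tensor,
                       interleave: bool, rope_scale: float,
                       rope_theta: float) -> None:
    # A custom op that mutates its arguments must return None: outputs
    # may not alias mutated inputs. This is why the C++ entry points
    # below now return void instead of std::vector<torch::Tensor>.
    q_rope.copy_(q)  # placeholder for the real CUDA rotary kernel
    k_rope.copy_(k)
```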
flashinfer-aot/csrc_aot/flashinfer_ops.cu (17 additions, 23 deletions)
@@ -83,29 +83,23 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

-std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
-                                      torch::Tensor k_rope, torch::Tensor indptr,
-                                      torch::Tensor offsets, bool interleave, float rope_scale,
-                                      float rope_theta);
-
-std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor indptr, torch::Tensor offsets,
-                                              bool interleave, float rope_scale, float rope_theta,
-                                              float low_freq_factor, float high_freq_factor,
-                                              float old_context_length);
-
-std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor pos_ids, bool interleave,
-                                              float rope_scale, float rope_theta);
-
-std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                                      torch::Tensor q_rope, torch::Tensor k_rope,
-                                                      torch::Tensor pos_ids, bool interleave,
-                                                      float rope_scale, float rope_theta,
-                                                      float low_freq_factor, float high_freq_factor,
-                                                      float old_context_length);
+void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope,
+                torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale,
+                float rope_theta);
+
+void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets,
+                        bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
+                        float high_freq_factor, float old_context_length);
+
+void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                        float rope_scale, float rope_theta);
+
+void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                                float rope_scale, float rope_theta, float low_freq_factor,
+                                float high_freq_factor, float old_context_length);

torch::Tensor packbits(torch::Tensor x, const std::string& bitorder);

python/csrc/flashinfer_rope_ops.cu (14 additions, 20 deletions)
@@ -17,29 +17,23 @@

#include <vector>

-std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
-                                      torch::Tensor k_rope, torch::Tensor indptr,
-                                      torch::Tensor offsets, bool interleave, float rope_scale,
-                                      float rope_theta);
+void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope,
+                torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale,
+                float rope_theta);

-std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor indptr, torch::Tensor offsets,
-                                              bool interleave, float rope_scale, float rope_theta,
-                                              float low_freq_factor, float high_freq_factor,
-                                              float old_context_length);
+void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets,
+                        bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
+                        float high_freq_factor, float old_context_length);

-std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor pos_ids, bool interleave,
-                                              float rope_scale, float rope_theta);
+void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                        float rope_scale, float rope_theta);

-std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                                      torch::Tensor q_rope, torch::Tensor k_rope,
-                                                      torch::Tensor pos_ids, bool interleave,
-                                                      float rope_scale, float rope_theta,
-                                                      float low_freq_factor, float high_freq_factor,
-                                                      float old_context_length);
+void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                                float rope_scale, float rope_theta, float low_freq_factor,
+                                float high_freq_factor, float old_context_length);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("apply_rope", &apply_rope, "Apply RoPE");
python/csrc/rope.cu (14 additions, 28 deletions)
@@ -19,10 +19,9 @@

using namespace flashinfer;

-std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
-                                      torch::Tensor k_rope, torch::Tensor indptr,
-                                      torch::Tensor offsets, bool interleave, float rope_scale,
-                                      float rope_theta) {
+void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope,
+                torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale,
+                float rope_theta) {
  CHECK_CUDA(q);  // not necessarily contiguous
  CHECK_CUDA(k);  // not necessarily contiguous
  CHECK_INPUT(indptr);
@@ -65,14 +64,11 @@ std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::T
                    std::string(cudaGetErrorString(status)));
    return true;
  });
-
-  return {q_rope, k_rope};
}

-std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor pos_ids, bool interleave,
-                                              float rope_scale, float rope_theta) {
+void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                        float rope_scale, float rope_theta) {
  CHECK_CUDA(q);  // not necessarily contiguous
  CHECK_CUDA(k);  // not necessarily contiguous
  CHECK_INPUT(pos_ids);
@@ -109,16 +105,12 @@ std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
                    std::string(cudaGetErrorString(status)));
    return true;
  });
-
-  return {q_rope, k_rope};
}

-std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
-                                              torch::Tensor q_rope, torch::Tensor k_rope,
-                                              torch::Tensor indptr, torch::Tensor offsets,
-                                              bool interleave, float rope_scale, float rope_theta,
-                                              float low_freq_factor, float high_freq_factor,
-                                              float old_context_length) {
+void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                        torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets,
+                        bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
+                        float high_freq_factor, float old_context_length) {
  CHECK_CUDA(q);  // not necessarily contiguous
  CHECK_CUDA(k);  // not necessarily contiguous
  CHECK_INPUT(indptr);
@@ -162,16 +154,12 @@ std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
                    std::string(cudaGetErrorString(status)));
    return true;
  });
-
-  return {q_rope, k_rope};
}

-std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k,
-                                                      torch::Tensor q_rope, torch::Tensor k_rope,
-                                                      torch::Tensor pos_ids, bool interleave,
-                                                      float rope_scale, float rope_theta,
-                                                      float low_freq_factor, float high_freq_factor,
-                                                      float old_context_length) {
+void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
+                                float rope_scale, float rope_theta, float low_freq_factor,
+                                float high_freq_factor, float old_context_length) {
  CHECK_CUDA(q);  // not necessarily contiguous
  CHECK_CUDA(k);  // not necessarily contiguous
  CHECK_INPUT(pos_ids);
@@ -209,6 +197,4 @@ std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Te
                    std::string(cudaGetErrorString(status)));
    return true;
  });
-
-  return {q_rope, k_rope};
}
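With the bindings returning `void`, the caller allocates the output buffers and reads the results from them after the call. A hedged usage sketch, continuing the hypothetical `demo::apply_rope_inplace` op from the earlier sketch (flashinfer's real Python wrappers may differ):

```python
import torch

# Assumes the hypothetical demo::apply_rope_inplace registration above.
q = torch.randn(8, 32, 128)
k = torch.randn(8, 32, 128)
q_rope = torch.empty_like(q)  # caller-owned output buffer
k_rope = torch.empty_like(k)  # caller-owned output buffer

apply_rope_inplace(q, k, q_rope, k_rope, False, 1.0, 1e4)
# Nothing is returned; the rotated embeddings now live in q_rope / k_rope.
```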