sampling: expose sampling APIs in pytorch (#238)
We only have C++ and TVM sampling APIs at the moment; this PR exposes the sampling APIs in PyTorch.
yzh119 authored May 7, 2024
1 parent 15db5de commit 0929023
Showing 13 changed files with 583 additions and 1 deletion.
13 changes: 13 additions & 0 deletions docs/api/python/norm.rst
@@ -0,0 +1,13 @@
.. _apinorm:

flashinfer.norm
===============

Kernels for normalization layers.

.. currentmodule:: flashinfer.norm

.. autosummary::
   :toctree: _generate

   rmsnorm
15 changes: 15 additions & 0 deletions docs/api/python/sampling.rst
@@ -0,0 +1,15 @@
.. _apisampling:

flashinfer.sampling
===================

Kernels for LLM sampling.

.. currentmodule:: flashinfer.sampling

.. autosummary::
   :toctree: ../../generated

   sampling_from_probs
   top_p_sampling_from_probs
   top_k_sampling_from_probs
3 changes: 2 additions & 1 deletion docs/index.rst
@@ -31,4 +31,5 @@ FlashInfer is a library for Large Language Models that provides high-performance
   api/python/prefill
   api/python/cascade
   api/python/page
   api/python/sampling
   api/python/norm
6 changes: 6 additions & 0 deletions python/csrc/flashinfer_ops.cu
@@ -29,6 +29,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("merge_states", &merge_states, "Merge multiple self-attention states");
  m.def("batch_decode_with_padded_kv_cache", &batch_decode_with_padded_kv_cache,
        "Multi-request batch decode with padded KV-Cache operator");
  m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities");
  m.def("top_k_sampling_from_probs", &top_k_sampling_from_probs,
        "Top-k sampling from probabilities");
  m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs,
        "Top-p sampling from probabilities");
  m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
  py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
                                                        "BatchDecodeWithPagedKVCachePyTorchWrapper")
      .def(py::init(&BatchDecodeWithPagedKVCachePyTorchWrapper::Create))
11 changes: 11 additions & 0 deletions python/csrc/flashinfer_ops.h
@@ -52,6 +52,17 @@ std::vector<torch::Tensor> batch_decode_with_padded_kv_cache(
    unsigned int pos_encoding_mode, float sm_scale, float rope_scale, float rope_theta,
    bool return_lse);

torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples);

std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
                                                     torch::Tensor uniform_samples, double top_p);

std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
                                                     torch::Tensor uniform_samples,
                                                     unsigned int top_k);

torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps);

class BatchDecodeWithPagedKVCachePyTorchWrapper {
 public:
  static BatchDecodeWithPagedKVCachePyTorchWrapper Create(unsigned int layout) {
43 changes: 43 additions & 0 deletions python/csrc/norm.cu
@@ -0,0 +1,43 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/norm.cuh>

#include "flashinfer_ops.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps) {
  CHECK_INPUT(x);
  CHECK_INPUT(w);
  CHECK_DIM(2, x);  // x: (batch_size, hidden_size)
  CHECK_DIM(1, w);  // w: (hidden_size)
  CHECK_EQ(x.size(1), w.size(0));
  unsigned int batch_size = x.size(0);
  unsigned int hidden_size = x.size(1);

  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
  auto y = torch::empty_like(x);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE(x.scalar_type(), c_type, [&] {
    cudaError_t status = norm::RMSNorm(
        static_cast<c_type*>(x.data_ptr()), static_cast<c_type*>(w.data_ptr()),
        static_cast<c_type*>(y.data_ptr()), batch_size, hidden_size, eps, torch_current_stream);
    TORCH_CHECK(status == cudaSuccess,
                "RMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
    return true;
  });
  return y;
}
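
For intuition, here is a pure-PyTorch sketch of what this binding is expected to compute: a hedged reference assuming the kernel implements the standard RMSNorm, y = x / sqrt(mean(x^2) + eps) * w. The helper rmsnorm_ref below is hypothetical and not part of this PR; it is only useful for cross-checking the kernel output.

import torch

def rmsnorm_ref(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize each row by its root-mean-square, then scale by w.
    # Accumulate in float32 for stability and cast back to the input dtype.
    x32 = x.float()
    rms = torch.sqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x32 / rms * w.float()).to(x.dtype)

# Cross-check against the CUDA kernel (loose tolerances for half precision):
# x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
# w = torch.randn(128, dtype=torch.float16, device="cuda")
# torch.testing.assert_close(flashinfer.rmsnorm(x, w), rmsnorm_ref(x, w), rtol=2e-2, atol=2e-2)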
98 changes: 98 additions & 0 deletions python/csrc/sampling.cu
@@ -0,0 +1,98 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/sampling.cuh>

#include "flashinfer_ops.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples) {
  CHECK_INPUT(probs);
  CHECK_INPUT(uniform_samples);
  CHECK_DIM(2, probs);            // probs: (batch_size, vocab_size)
  CHECK_DIM(1, uniform_samples);  // uniform_samples: (batch_size)
  CHECK_EQ(probs.size(0), uniform_samples.size(0));
  unsigned int batch_size = probs.size(0);
  unsigned int vocab_size = probs.size(1);
  probs = probs.to(torch::kFloat32);
  uniform_samples = uniform_samples.to(torch::kFloat32);

  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
  auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(probs.device()));

  cudaError_t status = sampling::SamplingFromProb(
      static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
      static_cast<int*>(samples.data_ptr()), batch_size, vocab_size, torch_current_stream);
  TORCH_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " +
                                         std::string(cudaGetErrorString(status)));
  return samples;
}

std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
                                                     torch::Tensor uniform_samples, double top_p) {
  CHECK_INPUT(probs);
  CHECK_INPUT(uniform_samples);
  CHECK_DIM(2, probs);            // probs: (batch_size, vocab_size)
  CHECK_DIM(2, uniform_samples);  // uniform_samples: (max_top_p_rounds, batch_size)
  CHECK_EQ(probs.size(0), uniform_samples.size(1));
  unsigned int batch_size = probs.size(0);
  unsigned int vocab_size = probs.size(1);
  unsigned int max_top_p_rounds = uniform_samples.size(0);
  probs = probs.to(torch::kFloat32);
  uniform_samples = uniform_samples.to(torch::kFloat32);

  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
  auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(probs.device()));
  auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(probs.device()));

  cudaError_t status = sampling::TopPSamplingFromProb<float, int>(
      static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
      static_cast<int*>(samples.data_ptr()), static_cast<bool*>(success.data_ptr()), top_p,
      batch_size, vocab_size, max_top_p_rounds, torch_current_stream);
  TORCH_CHECK(status == cudaSuccess, "TopPSamplingFromProbs failed with error code " +
                                         std::string(cudaGetErrorString(status)));

  return {samples, success};
}

std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
                                                     torch::Tensor uniform_samples,
                                                     unsigned int top_k) {
  CHECK_INPUT(probs);
  CHECK_INPUT(uniform_samples);
  CHECK_DIM(2, probs);            // probs: (batch_size, vocab_size)
  CHECK_DIM(2, uniform_samples);  // uniform_samples: (max_top_k_rounds, batch_size)
  CHECK_EQ(probs.size(0), uniform_samples.size(1));
  unsigned int batch_size = probs.size(0);
  unsigned int vocab_size = probs.size(1);
  unsigned int max_top_k_rounds = uniform_samples.size(0);
  probs = probs.to(torch::kFloat32);
  uniform_samples = uniform_samples.to(torch::kFloat32);

  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
  auto samples = torch::empty({batch_size}, torch::dtype(torch::kInt32).device(probs.device()));
  auto success = torch::empty({batch_size}, torch::dtype(torch::kBool).device(probs.device()));

  cudaError_t status = sampling::TopKSamplingFromProb<float, int>(
      static_cast<float*>(probs.data_ptr()), static_cast<float*>(uniform_samples.data_ptr()),
      static_cast<int*>(samples.data_ptr()), static_cast<bool*>(success.data_ptr()), top_k,
      batch_size, vocab_size, max_top_k_rounds, torch_current_stream);
  TORCH_CHECK(status == cudaSuccess, "TopKSamplingFromProbs failed with error code " +
                                         std::string(cudaGetErrorString(status)));

  return {samples, success};
}
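
A hedged usage sketch of the three bindings above. Tensor shapes follow the CHECK_DIM/CHECK_EQ assertions in this file; the top-level names assume the re-exports added to python/flashinfer/__init__.py below, and arguments are passed positionally because the exact Python-side signatures live in the wrapper module, which is not reproduced in this excerpt.

import torch
import flashinfer

batch_size, vocab_size, max_rounds = 4, 32000, 32
logits = torch.randn(batch_size, vocab_size, device="cuda")
probs = torch.softmax(logits, dim=-1)  # each row must be a valid probability distribution

# Plain sampling consumes one uniform sample per row.
u = torch.rand(batch_size, device="cuda")
samples = flashinfer.sampling_from_probs(probs, u)  # int32 token ids, shape (batch_size,)

# Top-p / top-k use rejection sampling with up to max_rounds rounds, so they
# consume (max_rounds, batch_size) uniform samples and also return a per-row
# boolean flag indicating whether sampling succeeded within the round budget.
u_rounds = torch.rand(max_rounds, batch_size, device="cuda")
samples, success = flashinfer.top_p_sampling_from_probs(probs, u_rounds, 0.9)
samples, success = flashinfer.top_k_sampling_from_probs(probs, u_rounds, 40)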
6 changes: 6 additions & 0 deletions python/flashinfer/__init__.py
@@ -35,6 +35,12 @@
    BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
)
from .page import append_paged_kv_cache
from .sampling import (
    sampling_from_probs,
    top_p_sampling_from_probs,
    top_k_sampling_from_probs,
)
from .norm import rmsnorm

try:
from ._build_meta import __version__ as __version__
49 changes: 49 additions & 0 deletions python/flashinfer/norm.py
@@ -0,0 +1,49 @@
"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import torch

try:
    from . import _kernels
except ImportError as e:
    import os
    import logging

    if os.environ.get("BUILD_DOC", "0") == "1":
        _kernels = None
        logging.warning("Kernels are not loaded in documentation build mode.")
    else:
        raise e


def rmsnorm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
    r"""Root mean square normalization.

    Parameters
    ----------
    x: torch.Tensor
        Input tensor, shape (batch_size, hidden_size).
    w: torch.Tensor
        Weight tensor, shape (hidden_size,).
    eps: float
        Epsilon for numerical stability.

    Returns
    -------
    y: torch.Tensor
        Normalized tensor, shape (batch_size, hidden_size).
    """
    return _kernels.rmsnorm(x, w, eps)
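
A minimal call sketch, assuming the compiled _kernels extension and a CUDA device are available:

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
w = torch.ones(4096, dtype=torch.float16, device="cuda")
y = rmsnorm(x, w)  # same shape and dtype as x: (8, 4096), float16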