Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions cpp/serve/engine_actions/batch_decode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,10 @@ class BatchDecodeActionObj : public EngineActionObj {
// Fill range [0, num_rsentries) into `sample_indices`.
std::vector<int> sample_indices(num_rsentries);
std::iota(sample_indices.begin(), sample_indices.end(), 0);
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbBeforeTopP(
probs_on_device, sample_indices, request_ids, generation_cfg, rngs);
NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(
probs_on_device, sample_indices, request_ids, generation_cfg);
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP(
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);
ICHECK_EQ(sample_results.size(), num_rsentries);

// - Update the committed tokens of states.
Expand Down
6 changes: 4 additions & 2 deletions cpp/serve/engine_actions/new_request_prefill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,10 @@ class NewRequestPrefillActionObj : public EngineActionObj {
rsentry_activated.push_back(true);
}
}
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbBeforeTopP(
probs_on_device, sample_indices, request_ids, generation_cfg, rngs);
NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(
probs_on_device, sample_indices, request_ids, generation_cfg);
std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP(
renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);
ICHECK_EQ(sample_results.size(), rsentries_for_sample.size());

// - Update the committed tokens of states.
Expand Down
40 changes: 32 additions & 8 deletions cpp/serve/sampler/gpu_sampler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class GPUSampler : public SamplerObj {
uniform_samples_host_ = NDArray::Empty({max_num_sample}, dtype_f32_, device_cpu);
sample_indices_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu);
top_p_host_ = NDArray::Empty({max_num_sample}, dtype_f32_, device_cpu);
top_p_init_pivots_host_ =
NDArray::Empty({max_num_sample, num_top_p_cutoff_pivots_}, dtype_f32_, device_cpu);
top_prob_offsets_host_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device_cpu);
draft_tokens_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu);
token_tree_first_child_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu);
Expand All @@ -73,6 +75,8 @@ class GPUSampler : public SamplerObj {
uniform_samples_device_ = NDArray::Empty({max_num_sample}, dtype_f32_, device);
sample_indices_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device);
top_p_device_ = NDArray::Empty({max_num_sample}, dtype_f32_, device);
top_p_init_pivots_device_ =
NDArray::Empty({max_num_sample, num_top_p_cutoff_pivots_}, dtype_f32_, device);
top_prob_offsets_device_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device);
draft_tokens_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device);
token_tree_first_child_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device);
Expand Down Expand Up @@ -118,21 +122,35 @@ class GPUSampler : public SamplerObj {
return probs_on_device;
}

// - Argsort the probability.
Array<NDArray> argsort_results = gpu_argsort_probs_func_(probs_on_device);
ICHECK_EQ(argsort_results.size(), 2);
NDArray sorted_probs_on_device = argsort_results[0];
NDArray sorted_indices_on_device = argsort_results[1];

// - Copy auxiliary array for top-p.
// - Copy auxiliary array for top-p and initial pivots.
NDArray top_p_host = top_p_host_.CreateView({num_probs}, dtype_f32_);
NDArray top_p_device = top_p_device_.CreateView({num_probs}, dtype_f32_);
CopyArray(/*src=*/top_p_host, /*dst=*/top_p_device, copy_stream_);

NDArray top_p_init_pivots_host =
top_p_init_pivots_host_.CreateView({num_probs, num_top_p_cutoff_pivots_}, dtype_f32_);
NDArray top_p_init_pivots_device =
top_p_init_pivots_device_.CreateView({num_probs, num_top_p_cutoff_pivots_}, dtype_f32_);
const float* p_top_p = static_cast<const float*>(top_p_host->data);
float* p_top_p_init_pivots = static_cast<float*>(top_p_init_pivots_host->data);
for (int i = 0; i < num_probs; ++i) {
if (1 - p_top_p[i] >= 0.02) {
p_top_p_init_pivots[i * num_top_p_cutoff_pivots_] =
std::min(1 - p_top_p[i], static_cast<float>(0.5));
p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 1] = 0.02;
p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 2] = 0.01;
} else {
p_top_p_init_pivots[i * num_top_p_cutoff_pivots_] = 1 - p_top_p[i];
p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 1] = (1 - p_top_p[i]) / 2;
p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 2] = (1 - p_top_p[i]) / 4;
}
}
CopyArray(/*src=*/top_p_init_pivots_host, /*dst=*/top_p_init_pivots_device, copy_stream_);
SyncCopyStream(device_, compute_stream_, copy_stream_);

// - Renormalize the prob with top p.
NDArray renormed_probs_on_device =
gpu_renormalize_by_top_p_func_(probs_on_device, sorted_probs_on_device, top_p_device);
gpu_renormalize_by_top_p_func_(probs_on_device, top_p_device, top_p_init_pivots_device);

RECORD_EVENT(trace_recorder_, request_ids, "finish renormalization by top p");
return renormed_probs_on_device;
Expand Down Expand Up @@ -500,6 +518,9 @@ class GPUSampler : public SamplerObj {
<< "GPU sampler requires the top_p values for each prob distribution are the same.";
}
}
for (int i = 0; i < num_probs; ++i) {
p_top_p[i] = std::max(p_top_p[i], eps_);
}
return need_top_p;
}

Expand Down Expand Up @@ -665,6 +686,7 @@ class GPUSampler : public SamplerObj {
NDArray uniform_samples_host_;
NDArray sample_indices_host_;
NDArray top_p_host_;
NDArray top_p_init_pivots_host_;
NDArray top_prob_offsets_host_;
NDArray draft_tokens_host_;
NDArray token_tree_first_child_host_;
Expand All @@ -678,6 +700,7 @@ class GPUSampler : public SamplerObj {
NDArray uniform_samples_device_;
NDArray sample_indices_device_;
NDArray top_p_device_;
NDArray top_p_init_pivots_device_;
NDArray top_prob_offsets_device_;
NDArray draft_tokens_device_;
NDArray token_tree_first_child_device_;
Expand All @@ -691,6 +714,7 @@ class GPUSampler : public SamplerObj {
// The device stream for copying auxiliary data structure to GPU.
TVMStreamHandle copy_stream_ = nullptr;
const float eps_ = 1e-5;
const int num_top_p_cutoff_pivots_ = 3;
};

Sampler Sampler::CreateGPUSampler(int max_num_sample, int vocab_size, FunctionTable* ft,
Expand Down
58 changes: 27 additions & 31 deletions python/mlc_llm/compiler_pass/attach_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from tvm.relax.frontend import nn
from tvm.script import tir as T

from ..op.batch_spec_verify import batch_spec_verify
from mlc_llm.op.batch_spec_verify import batch_spec_verify
from mlc_llm.op.top_p_pivot import top_p_pivot, top_p_renorm


@tvm.transform.module_pass(opt_level=0, name="AttachGPUSamplingFunc")
Expand Down Expand Up @@ -49,7 +50,7 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR
_attach_sample_with_top_p(bb, vocab_size),
_attach_take_probs_func(bb, vocab_size),
_attach_batch_verifier(bb, vocab_size),
_attach_renormalize_by_top_p(bb, vocab_size),
_attach_renormalize_by_top_p(bb, vocab_size, self.target),
]
]

Expand Down Expand Up @@ -227,41 +228,36 @@ def _attach_sample_with_top_p( # pylint: disable=too-many-locals
return gv


def _attach_renormalize_by_top_p(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr):
def _attach_renormalize_by_top_p(
bb: relax.BlockBuilder, vocab_size: tir.PrimExpr, target: tvm.target.Target
):
batch_size = tir.Var("batch_size", "int64")
num_pivots = 3
probs = relax.Var("probs", relax.TensorStructInfo((batch_size, vocab_size), "float32"))
sorted_probs = relax.Var(
"sorted_probs", relax.TensorStructInfo((batch_size, vocab_size), "float32")
)
top_p = relax.Var("top_p", relax.TensorStructInfo((batch_size,), "float32"))
with bb.function("renormalize_by_top_p", [probs, sorted_probs, top_p]):
init_pivots = relax.Var(
"init_pivots", relax.TensorStructInfo((batch_size, num_pivots), "float32")
)
with bb.function("renormalize_by_top_p", [probs, top_p, init_pivots]):
with bb.dataflow():
probs_tensor = nn.wrap_nested(probs, name="probs")
sorted_probs_tensor = nn.wrap_nested(sorted_probs, name="sorted_probs")
top_p_shape = relax.ShapeExpr([batch_size, 1])
top_p_tensor = nn.wrap_nested(
relax.call_pure_packed(
"vm.builtin.reshape",
top_p,
top_p_shape,
sinfo_args=relax.TensorStructInfo(top_p_shape, "float32"),
),
name="sample_indices",
)
top_k_tensor = nn.tensor_ir_op(
full,
name_hint="full",
args=[vocab_size],
out=nn.Tensor.placeholder(
[batch_size, 1],
"int32",
),
cutoff_output = bb.emit(
relax.call_tir(
bb.add_func(top_p_pivot(num_pivots, target), "top_p_pivot_cutoff"),
args=[probs, top_p, init_pivots],
out_sinfo=[top_p.struct_info, top_p.struct_info], # pylint: disable=no-member
)
)
renormalized_probs = nn.renormalize_top_p_top_k_prob(
probs_tensor, sorted_probs_tensor, top_p_tensor, top_k_tensor
final_pivot = cutoff_output[0]
renorm_sum = cutoff_output[1]
renormalized_probs = bb.emit(
relax.call_tir(
bb.add_func(top_p_renorm(target), "top_p_renorm_after_cutoff"),
args=[probs, final_pivot, renorm_sum],
out_sinfo=probs.struct_info, # pylint: disable=no-member
)
)
bb.emit_output(renormalized_probs._expr) # pylint: disable=protected-access
gv = bb.emit_func_output(renormalized_probs._expr) # pylint: disable=protected-access
bb.emit_output(renormalized_probs)
gv = bb.emit_func_output(renormalized_probs)
return gv


Expand Down
9 changes: 9 additions & 0 deletions python/mlc_llm/compiler_pass/rewrite_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,15 @@ def visit_call_(self, call: relax.Call) -> Expr: # pylint: disable=arguments-re
def _get_lse_and_softmax_func( # pylint: disable=too-many-locals,too-many-statements
target: tvm.target.Target, chunk_size: int
):
# NOTE: A quick note on the softmax implementation.
# We once tried to multiply every element by log2e which can be computed
# potentially more efficiently on hardware.
# However, when the input values are large, multiplying by the factor of log2e
# causes numerical issue in float32 dtype.
# This leads to the softmax output not summing up to 1.
# For numerical stability, we removed the log2e factor and switched back
# to the standard log/exp computation.

# pylint: disable=invalid-name
@T.prim_func
def chunk_lse(var_A: T.handle, var_chunked_lse: T.handle): # pylint: disable=too-many-locals
Expand Down
42 changes: 33 additions & 9 deletions python/mlc_llm/op/top_p_pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import tvm
from tvm.script import tir as T

from mlc_llm.support.max_thread_check import get_max_num_threads_per_block

# mypy: disable-error-code="attr-defined,valid-type,name-defined"
# pylint: disable=too-many-locals,invalid-name,too-many-arguments,unnecessary-lambda
# pylint: disable=too-many-statements,line-too-long,too-many-nested-blocks,too-many-branches


def top_p_pivot(pN):
def top_p_pivot(pN, target: tvm.target.Target):
"""Top-p pivot function. This function finds the pivot to cut-off top-p percentile.

A valide pivot should satisfy the following conditions:
Expand All @@ -23,19 +25,26 @@ def top_p_pivot(pN):
prob:
The probability vector

top_p_global:
top_p_arr:
The top-p threshold

init_pivots:
The initial pivot candidates

final_pivot:
The final pivot to cut-off top-p percentile

final_lsum:
The final sum of the values after top-p filtering.
"""
TX = 1024
K = 32
eps_LR = 1e-7

max_num_threads_per_block = get_max_num_threads_per_block(target)
if max_num_threads_per_block < TX:
TX = max_num_threads_per_block

def _var(dtype="int32"):
return T.alloc_buffer((1,), dtype, scope="local")

Expand All @@ -46,7 +55,7 @@ def valid(lsum, lmin, cmin, top_p):
@T.prim_func(private=True)
def _func(
var_prob: T.handle,
top_p_global: T.buffer([1], dtype="float32"),
var_top_p_arr: T.handle,
var_init_pivots: T.handle,
var_final_pivot: T.handle,
var_final_lsum: T.handle,
Expand All @@ -55,7 +64,8 @@ def _func(
B = T.int32()
N = T.int32()
prob = T.match_buffer(var_prob, (B, N,), "float32")
init_pivots = T.match_buffer(var_init_pivots, (pN,), "float32")
top_p_arr = T.match_buffer(var_top_p_arr, (B,), dtype="float32")
init_pivots = T.match_buffer(var_init_pivots, (B, pN), "float32")
final_pivot = T.match_buffer(var_final_pivot, (B,), "float32")
final_lsum = T.match_buffer(var_final_lsum, (B,), "float32")

Expand Down Expand Up @@ -92,7 +102,7 @@ def _func(
with T.block("CTA"):
b, tx = T.axis.remap("SS", [_bx, _tx])

top_p[0] = top_p_global[0]
top_p[0] = top_p_arr[b]

if tx == 0:
# leader thread initializes L, R
Expand All @@ -105,8 +115,14 @@ def _func(
R_local[0] = R[0]
for i in T.unroll(0, pN):
# pivots are in descending order
pivot[i] = init_pivots[i]
pivot[i] = init_pivots[b, i]
find_pivot_local[0] = False
if L_local[0] - R_local[0] <= eps_LR:
# When the initial value is too small, set the result directly.
if tx == 0:
final_lsum[b] = 1.0
final_pivot[b] = 0.0
find_pivot_local[0] = True

while T.tvm_thread_invariant(
L_local[0] - R_local[0] > eps_LR
Expand All @@ -118,7 +134,7 @@ def _func(
### get lsum, lmin, total_sum
for pidx in T.unroll(0, pN):
lsum[pidx] = 0.0
lmin[pidx] = 1.0
lmin[pidx] = T.max_value("float32")
cmin[pidx] = 0
total_sum[0] = 0.0
it[0] = 0
Expand Down Expand Up @@ -226,6 +242,7 @@ def _func(
final_lsum[b] = lsum[pidx]
elif lsum[pidx] - lmin[pidx] * cmin[pidx] >= top_p[0]:
R[0] = pivot[pidx]
final_lsum[b] = lsum[pidx]
elif lsum[pidx] < top_p[0]:
L[0] = pivot[pidx]
it[0] += 1
Expand All @@ -243,13 +260,15 @@ def _func(
if tx == 0:
# leader thread writes back the pivot
if T.Not(find_pivot_local[0]):
final_pivot[b] = -1e5
final_pivot[b] = R_local[0]
if R_local[0] == eps_LR:
final_lsum[b] = lsum[pN - 1]
# fmt: on

return _func


def top_p_renorm():
def top_p_renorm(target: tvm.target.Target = None):
"""Top-p renormalization function. This function renormalizes the probability vector.

Given the pivot, the probability vector is renormalized as follows:
Expand All @@ -273,6 +292,11 @@ def top_p_renorm():
TX = 1024
CTA_COUNT = 512

if target:
max_num_threads_per_block = get_max_num_threads_per_block(target)
if max_num_threads_per_block < TX:
TX = max_num_threads_per_block

def _var(dtype="int32"):
return T.alloc_buffer((1,), dtype, scope="local")

Expand Down