diff --git a/cpp/serve/engine_actions/batch_decode.cc b/cpp/serve/engine_actions/batch_decode.cc index ecff914baa..3c5c8fdb5b 100644 --- a/cpp/serve/engine_actions/batch_decode.cc +++ b/cpp/serve/engine_actions/batch_decode.cc @@ -114,8 +114,10 @@ class BatchDecodeActionObj : public EngineActionObj { // Fill range [0, num_rsentries) into `sample_indices`. std::vector sample_indices(num_rsentries); std::iota(sample_indices.begin(), sample_indices.end(), 0); - std::vector sample_results = sampler_->BatchSampleTokensWithProbBeforeTopP( - probs_on_device, sample_indices, request_ids, generation_cfg, rngs); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), num_rsentries); // - Update the committed tokens of states. diff --git a/cpp/serve/engine_actions/new_request_prefill.cc b/cpp/serve/engine_actions/new_request_prefill.cc index f801b1e282..5a5847aaa0 100644 --- a/cpp/serve/engine_actions/new_request_prefill.cc +++ b/cpp/serve/engine_actions/new_request_prefill.cc @@ -229,8 +229,10 @@ class NewRequestPrefillActionObj : public EngineActionObj { rsentry_activated.push_back(true); } } - std::vector sample_results = sampler_->BatchSampleTokensWithProbBeforeTopP( - probs_on_device, sample_indices, request_ids, generation_cfg, rngs); + NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP( + probs_on_device, sample_indices, request_ids, generation_cfg); + std::vector sample_results = sampler_->BatchSampleTokensWithProbAfterTopP( + renormalized_probs, sample_indices, request_ids, generation_cfg, rngs); ICHECK_EQ(sample_results.size(), rsentries_for_sample.size()); // - Update the committed tokens of states. diff --git a/cpp/serve/sampler/gpu_sampler.cc b/cpp/serve/sampler/gpu_sampler.cc index 36cb6e5c0a..1a013a9627 100644 --- a/cpp/serve/sampler/gpu_sampler.cc +++ b/cpp/serve/sampler/gpu_sampler.cc @@ -60,6 +60,8 @@ class GPUSampler : public SamplerObj { uniform_samples_host_ = NDArray::Empty({max_num_sample}, dtype_f32_, device_cpu); sample_indices_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); top_p_host_ = NDArray::Empty({max_num_sample}, dtype_f32_, device_cpu); + top_p_init_pivots_host_ = + NDArray::Empty({max_num_sample, num_top_p_cutoff_pivots_}, dtype_f32_, device_cpu); top_prob_offsets_host_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device_cpu); draft_tokens_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); token_tree_first_child_host_ = NDArray::Empty({max_num_sample}, dtype_i32_, device_cpu); @@ -73,6 +75,8 @@ class GPUSampler : public SamplerObj { uniform_samples_device_ = NDArray::Empty({max_num_sample}, dtype_f32_, device); sample_indices_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); top_p_device_ = NDArray::Empty({max_num_sample}, dtype_f32_, device); + top_p_init_pivots_device_ = + NDArray::Empty({max_num_sample, num_top_p_cutoff_pivots_}, dtype_f32_, device); top_prob_offsets_device_ = NDArray::Empty({max_num_sample * 5}, dtype_i32_, device); draft_tokens_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); token_tree_first_child_device_ = NDArray::Empty({max_num_sample}, dtype_i32_, device); @@ -118,21 +122,35 @@ class GPUSampler : public SamplerObj { return probs_on_device; } - // - Argsort the probability. - Array argsort_results = gpu_argsort_probs_func_(probs_on_device); - ICHECK_EQ(argsort_results.size(), 2); - NDArray sorted_probs_on_device = argsort_results[0]; - NDArray sorted_indices_on_device = argsort_results[1]; - - // - Copy auxiliary array for top-p. + // - Copy auxiliary array for top-p and initial pivots. NDArray top_p_host = top_p_host_.CreateView({num_probs}, dtype_f32_); NDArray top_p_device = top_p_device_.CreateView({num_probs}, dtype_f32_); CopyArray(/*src=*/top_p_host, /*dst=*/top_p_device, copy_stream_); + + NDArray top_p_init_pivots_host = + top_p_init_pivots_host_.CreateView({num_probs, num_top_p_cutoff_pivots_}, dtype_f32_); + NDArray top_p_init_pivots_device = + top_p_init_pivots_device_.CreateView({num_probs, num_top_p_cutoff_pivots_}, dtype_f32_); + const float* p_top_p = static_cast(top_p_host->data); + float* p_top_p_init_pivots = static_cast(top_p_init_pivots_host->data); + for (int i = 0; i < num_probs; ++i) { + if (1 - p_top_p[i] >= 0.02) { + p_top_p_init_pivots[i * num_top_p_cutoff_pivots_] = + std::min(1 - p_top_p[i], static_cast(0.5)); + p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 1] = 0.02; + p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 2] = 0.01; + } else { + p_top_p_init_pivots[i * num_top_p_cutoff_pivots_] = 1 - p_top_p[i]; + p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 1] = (1 - p_top_p[i]) / 2; + p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 2] = (1 - p_top_p[i]) / 4; + } + } + CopyArray(/*src=*/top_p_init_pivots_host, /*dst=*/top_p_init_pivots_device, copy_stream_); SyncCopyStream(device_, compute_stream_, copy_stream_); // - Renormalize the prob with top p. NDArray renormed_probs_on_device = - gpu_renormalize_by_top_p_func_(probs_on_device, sorted_probs_on_device, top_p_device); + gpu_renormalize_by_top_p_func_(probs_on_device, top_p_device, top_p_init_pivots_device); RECORD_EVENT(trace_recorder_, request_ids, "finish renormalization by top p"); return renormed_probs_on_device; @@ -500,6 +518,9 @@ class GPUSampler : public SamplerObj { << "GPU sampler requires the top_p values for each prob distribution are the same."; } } + for (int i = 0; i < num_probs; ++i) { + p_top_p[i] = std::max(p_top_p[i], eps_); + } return need_top_p; } @@ -665,6 +686,7 @@ class GPUSampler : public SamplerObj { NDArray uniform_samples_host_; NDArray sample_indices_host_; NDArray top_p_host_; + NDArray top_p_init_pivots_host_; NDArray top_prob_offsets_host_; NDArray draft_tokens_host_; NDArray token_tree_first_child_host_; @@ -678,6 +700,7 @@ class GPUSampler : public SamplerObj { NDArray uniform_samples_device_; NDArray sample_indices_device_; NDArray top_p_device_; + NDArray top_p_init_pivots_device_; NDArray top_prob_offsets_device_; NDArray draft_tokens_device_; NDArray token_tree_first_child_device_; @@ -691,6 +714,7 @@ class GPUSampler : public SamplerObj { // The device stream for copying auxiliary data structure to GPU. TVMStreamHandle copy_stream_ = nullptr; const float eps_ = 1e-5; + const int num_top_p_cutoff_pivots_ = 3; }; Sampler Sampler::CreateGPUSampler(int max_num_sample, int vocab_size, FunctionTable* ft, diff --git a/python/mlc_llm/compiler_pass/attach_sampler.py b/python/mlc_llm/compiler_pass/attach_sampler.py index 46dc40c106..5bf62257a1 100644 --- a/python/mlc_llm/compiler_pass/attach_sampler.py +++ b/python/mlc_llm/compiler_pass/attach_sampler.py @@ -7,7 +7,8 @@ from tvm.relax.frontend import nn from tvm.script import tir as T -from ..op.batch_spec_verify import batch_spec_verify +from mlc_llm.op.batch_spec_verify import batch_spec_verify +from mlc_llm.op.top_p_pivot import top_p_pivot, top_p_renorm @tvm.transform.module_pass(opt_level=0, name="AttachGPUSamplingFunc") @@ -49,7 +50,7 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR _attach_sample_with_top_p(bb, vocab_size), _attach_take_probs_func(bb, vocab_size), _attach_batch_verifier(bb, vocab_size), - _attach_renormalize_by_top_p(bb, vocab_size), + _attach_renormalize_by_top_p(bb, vocab_size, self.target), ] ] @@ -227,41 +228,36 @@ def _attach_sample_with_top_p( # pylint: disable=too-many-locals return gv -def _attach_renormalize_by_top_p(bb: relax.BlockBuilder, vocab_size: tir.PrimExpr): +def _attach_renormalize_by_top_p( + bb: relax.BlockBuilder, vocab_size: tir.PrimExpr, target: tvm.target.Target +): batch_size = tir.Var("batch_size", "int64") + num_pivots = 3 probs = relax.Var("probs", relax.TensorStructInfo((batch_size, vocab_size), "float32")) - sorted_probs = relax.Var( - "sorted_probs", relax.TensorStructInfo((batch_size, vocab_size), "float32") - ) top_p = relax.Var("top_p", relax.TensorStructInfo((batch_size,), "float32")) - with bb.function("renormalize_by_top_p", [probs, sorted_probs, top_p]): + init_pivots = relax.Var( + "init_pivots", relax.TensorStructInfo((batch_size, num_pivots), "float32") + ) + with bb.function("renormalize_by_top_p", [probs, top_p, init_pivots]): with bb.dataflow(): - probs_tensor = nn.wrap_nested(probs, name="probs") - sorted_probs_tensor = nn.wrap_nested(sorted_probs, name="sorted_probs") - top_p_shape = relax.ShapeExpr([batch_size, 1]) - top_p_tensor = nn.wrap_nested( - relax.call_pure_packed( - "vm.builtin.reshape", - top_p, - top_p_shape, - sinfo_args=relax.TensorStructInfo(top_p_shape, "float32"), - ), - name="sample_indices", - ) - top_k_tensor = nn.tensor_ir_op( - full, - name_hint="full", - args=[vocab_size], - out=nn.Tensor.placeholder( - [batch_size, 1], - "int32", - ), + cutoff_output = bb.emit( + relax.call_tir( + bb.add_func(top_p_pivot(num_pivots, target), "top_p_pivot_cutoff"), + args=[probs, top_p, init_pivots], + out_sinfo=[top_p.struct_info, top_p.struct_info], # pylint: disable=no-member + ) ) - renormalized_probs = nn.renormalize_top_p_top_k_prob( - probs_tensor, sorted_probs_tensor, top_p_tensor, top_k_tensor + final_pivot = cutoff_output[0] + renorm_sum = cutoff_output[1] + renormalized_probs = bb.emit( + relax.call_tir( + bb.add_func(top_p_renorm(target), "top_p_renorm_after_cutoff"), + args=[probs, final_pivot, renorm_sum], + out_sinfo=probs.struct_info, # pylint: disable=no-member + ) ) - bb.emit_output(renormalized_probs._expr) # pylint: disable=protected-access - gv = bb.emit_func_output(renormalized_probs._expr) # pylint: disable=protected-access + bb.emit_output(renormalized_probs) + gv = bb.emit_func_output(renormalized_probs) return gv diff --git a/python/mlc_llm/compiler_pass/rewrite_softmax.py b/python/mlc_llm/compiler_pass/rewrite_softmax.py index df879b37ec..47a5a168d7 100644 --- a/python/mlc_llm/compiler_pass/rewrite_softmax.py +++ b/python/mlc_llm/compiler_pass/rewrite_softmax.py @@ -79,6 +79,15 @@ def visit_call_(self, call: relax.Call) -> Expr: # pylint: disable=arguments-re def _get_lse_and_softmax_func( # pylint: disable=too-many-locals,too-many-statements target: tvm.target.Target, chunk_size: int ): + # NOTE: A quick note on the softmax implementation. + # We once tried to multiply every element by log2e which can be computed + # potentially more efficiently on hardware. + # However, when the input values are large, multiplying by the factor of log2e + # causes numerical issue in float32 dtype. + # This leads to the softmax output not summing up to 1. + # For numerical stability, we removed the log2e factor and switched back + # to the standard log/exp computation. + # pylint: disable=invalid-name @T.prim_func def chunk_lse(var_A: T.handle, var_chunked_lse: T.handle): # pylint: disable=too-many-locals diff --git a/python/mlc_llm/op/top_p_pivot.py b/python/mlc_llm/op/top_p_pivot.py index 9c97959bff..b9565a83c9 100644 --- a/python/mlc_llm/op/top_p_pivot.py +++ b/python/mlc_llm/op/top_p_pivot.py @@ -3,12 +3,14 @@ import tvm from tvm.script import tir as T +from mlc_llm.support.max_thread_check import get_max_num_threads_per_block + # mypy: disable-error-code="attr-defined,valid-type,name-defined" # pylint: disable=too-many-locals,invalid-name,too-many-arguments,unnecessary-lambda # pylint: disable=too-many-statements,line-too-long,too-many-nested-blocks,too-many-branches -def top_p_pivot(pN): +def top_p_pivot(pN, target: tvm.target.Target): """Top-p pivot function. This function finds the pivot to cut-off top-p percentile. A valide pivot should satisfy the following conditions: @@ -23,7 +25,7 @@ def top_p_pivot(pN): prob: The probability vector - top_p_global: + top_p_arr: The top-p threshold init_pivots: @@ -31,11 +33,18 @@ def top_p_pivot(pN): final_pivot: The final pivot to cut-off top-p percentile + + final_lsum: + The final sum of the values after top-p filtering. """ TX = 1024 K = 32 eps_LR = 1e-7 + max_num_threads_per_block = get_max_num_threads_per_block(target) + if max_num_threads_per_block < TX: + TX = max_num_threads_per_block + def _var(dtype="int32"): return T.alloc_buffer((1,), dtype, scope="local") @@ -46,7 +55,7 @@ def valid(lsum, lmin, cmin, top_p): @T.prim_func(private=True) def _func( var_prob: T.handle, - top_p_global: T.buffer([1], dtype="float32"), + var_top_p_arr: T.handle, var_init_pivots: T.handle, var_final_pivot: T.handle, var_final_lsum: T.handle, @@ -55,7 +64,8 @@ def _func( B = T.int32() N = T.int32() prob = T.match_buffer(var_prob, (B, N,), "float32") - init_pivots = T.match_buffer(var_init_pivots, (pN,), "float32") + top_p_arr = T.match_buffer(var_top_p_arr, (B,), dtype="float32") + init_pivots = T.match_buffer(var_init_pivots, (B, pN), "float32") final_pivot = T.match_buffer(var_final_pivot, (B,), "float32") final_lsum = T.match_buffer(var_final_lsum, (B,), "float32") @@ -92,7 +102,7 @@ def _func( with T.block("CTA"): b, tx = T.axis.remap("SS", [_bx, _tx]) - top_p[0] = top_p_global[0] + top_p[0] = top_p_arr[b] if tx == 0: # leader thread initializes L, R @@ -105,8 +115,14 @@ def _func( R_local[0] = R[0] for i in T.unroll(0, pN): # pivots are in descending order - pivot[i] = init_pivots[i] + pivot[i] = init_pivots[b, i] find_pivot_local[0] = False + if L_local[0] - R_local[0] <= eps_LR: + # When the initial value is too small, set the result directly. + if tx == 0: + final_lsum[b] = 1.0 + final_pivot[b] = 0.0 + find_pivot_local[0] = True while T.tvm_thread_invariant( L_local[0] - R_local[0] > eps_LR @@ -118,7 +134,7 @@ def _func( ### get lsum, lmin, total_sum for pidx in T.unroll(0, pN): lsum[pidx] = 0.0 - lmin[pidx] = 1.0 + lmin[pidx] = T.max_value("float32") cmin[pidx] = 0 total_sum[0] = 0.0 it[0] = 0 @@ -226,6 +242,7 @@ def _func( final_lsum[b] = lsum[pidx] elif lsum[pidx] - lmin[pidx] * cmin[pidx] >= top_p[0]: R[0] = pivot[pidx] + final_lsum[b] = lsum[pidx] elif lsum[pidx] < top_p[0]: L[0] = pivot[pidx] it[0] += 1 @@ -243,13 +260,15 @@ def _func( if tx == 0: # leader thread writes back the pivot if T.Not(find_pivot_local[0]): - final_pivot[b] = -1e5 + final_pivot[b] = R_local[0] + if R_local[0] == eps_LR: + final_lsum[b] = lsum[pN - 1] # fmt: on return _func -def top_p_renorm(): +def top_p_renorm(target: tvm.target.Target = None): """Top-p renormalization function. This function renormalizes the probability vector. Given the pivot, the probability vector is renormalized as follows: @@ -273,6 +292,11 @@ def top_p_renorm(): TX = 1024 CTA_COUNT = 512 + if target: + max_num_threads_per_block = get_max_num_threads_per_block(target) + if max_num_threads_per_block < TX: + TX = max_num_threads_per_block + def _var(dtype="int32"): return T.alloc_buffer((1,), dtype, scope="local")